easydel.trainers.utils

Utility functions and classes for EasyDeL trainers.

This module provides essential utilities for training, including:
  • JAX distributed configuration management

  • Dataset creation and manipulation functions

  • Data collation utilities for various training tasks

  • Conversation formatting and prompt processing

  • Memory and performance profiling tools

  • Training state management utilities

class easydel.trainers.utils.BCODataCollatorGrain(max_prompt_length: int, max_completion_length: int, pad_token_id: int, label_pad_token_id: int, is_encoder_decoder: bool)

Bases: _BCODataCollatorMixin

Grain-compatible BCO data collator.

class easydel.trainers.utils.BCODataCollatorTFDS(max_prompt_length: int, max_completion_length: int, pad_token_id: int, label_pad_token_id: int, is_encoder_decoder: bool)

Bases: _BCODataCollatorMixin

Data collator for BCO training with TFDS backends.

class easydel.trainers.utils.CollateMapTransform(collate_fn: callable)

Bases: MapTransform

Grain transform for applying custom collation functions.

Wraps a user-defined collation function as a Grain MapTransform, allowing custom batch processing logic in the Grain pipeline.

collate_fn

Callable that processes/collates data elements.

Type

callable

collate_fn: callable
map(element)

Apply the collation function to an element.

Parameters

element – Input data element to collate.

Returns

Collated/processed element.

class easydel.trainers.utils.DPODataCollatorWithPaddingGrain(max_prompt_length: int, max_completion_length: int, pad_token_id: int = 0, label_pad_token_id: int = -100, is_encoder_decoder: bool | None = False, output_arrays_only: bool = True, prepadded: bool = True)

Bases: object

DPO DataCollator class that pads the tokenized inputs to the maximum length of the batch.

classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

is_encoder_decoder: bool | None = False
label_pad_token_id: int = -100
max_completion_length: int
max_prompt_length: int
output_arrays_only: bool = True
pad_token_id: int = 0
prepadded: bool = True
replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

class easydel.trainers.utils.DPODataCollatorWithPaddingTFDS(max_prompt_length: int, max_completion_length: int, pad_token_id: int = 0, label_pad_token_id: int = -100, is_encoder_decoder: bool | None = False, output_arrays_only: bool = True, prepadded: bool = True)

Bases: object

Advanced data collator for DPO training with TFDS.

Extended version of DataCollatorForPreferenceTFDS with additional features for handling complex DPO scenarios including encoder-decoder models and pre-padded data.

max_prompt_length

Maximum length for prompt sequences.

Type

int

max_completion_length

Maximum length for completion sequences.

Type

int

pad_token_id

Token ID to use for padding (default 0).

Type

int

label_pad_token_id

Token ID for label padding (default -100).

Type

int

is_encoder_decoder

Whether using encoder-decoder architecture.

Type

bool | None

output_arrays_only

If True, only return array-type outputs.

Type

bool

prepadded

If True, assumes inputs are already padded.

Type

bool

classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

is_encoder_decoder: bool | None = False
label_pad_token_id: int = -100
max_completion_length: int
max_prompt_length: int
output_arrays_only: bool = True
pad_token_id: int = 0
prepadded: bool = True
replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

class easydel.trainers.utils.DataCollatorForCompletionOnlyLM(processing_class: Union[str, PreTrainedTokenizerBase], response_template: str | list[int], instruction_template: str | list[int] | None = None, *args, mlm: bool = False, ignore_index: int = -100, **kwargs)

Bases: object

Data collator for training on assistant completions only.

This collator masks out non-assistant tokens in the labels, ensuring that the loss is only calculated on the model’s completions (assistant responses) and not on the user prompts or system messages.

This is particularly useful for:
  • Instruction tuning, where you only want to train on responses.

  • Chat models, where user inputs should not contribute to the loss.

  • Maintaining the model’s ability to understand prompts without being trained to generate them.

processing_class

Tokenizer or processor for encoding text.

response_template

Template or token IDs marking response start.

instruction_template

Optional template marking instruction start.

mlm

Whether using masked language modeling (default False).

ignore_index

Label value to ignore in loss calculation.
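The label-masking rule described above can be sketched in NumPy as follows. mask_non_completion_labels is a hypothetical helper written for illustration, not EasyDeL’s implementation; it assumes the response template has already been tokenized to a list of IDs, ignores instruction_template, and (as an assumption of this sketch) masks a row entirely when the template is not found.

```python
import numpy as np


def mask_non_completion_labels(input_ids, response_token_ids, ignore_index=-100):
    """Hypothetical sketch: ignore every label up to and including the response template."""
    input_ids = np.asarray(input_ids)
    labels = input_ids.copy()
    template = np.asarray(response_token_ids)
    n = len(template)
    for row in range(input_ids.shape[0]):
        start = None
        # Scan for the first occurrence of the template in this row.
        for i in range(input_ids.shape[1] - n + 1):
            if np.array_equal(input_ids[row, i:i + n], template):
                start = i + n  # first token after the template
                break
        if start is None:
            labels[row, :] = ignore_index  # no response found: nothing to train on
        else:
            labels[row, :start] = ignore_index  # prompt and template tokens
    return labels
```

Only the tokens after the template keep their IDs as labels, so the loss covers the assistant response alone.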

jax_call(examples: list[list[int] | Any | dict[str, Any]]) → dict[str, Any]
jax_mask_tokens(inputs: Any, special_tokens_mask: Any | None = None) → tuple[Any, Any]

Prepare masked-token inputs and labels for masked language modeling. Of the tokens selected for masking, 80% are replaced with the mask token, 10% with a random token, and 10% are left unchanged.
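The 80/10/10 scheme can be sketched in NumPy as below. This is a hypothetical stand-alone function, not the method’s actual code; mask_token_id, vocab_size, the selection probability mlm_probability (0.15 here, the usual BERT value) and the seed are assumptions of the sketch.

```python
import numpy as np


def mask_tokens(inputs, mask_token_id, vocab_size, special_tokens_mask=None,
                mlm_probability=0.15, ignore_index=-100, seed=0):
    """Hypothetical BERT-style 80/10/10 masking sketch (not EasyDeL's code)."""
    rng = np.random.default_rng(seed)
    inputs = np.array(inputs, copy=True)
    labels = inputs.copy()

    # Select positions to mask; special tokens are never selected.
    prob = np.full(inputs.shape, mlm_probability)
    if special_tokens_mask is not None:
        prob[np.asarray(special_tokens_mask, dtype=bool)] = 0.0
    selected = rng.random(inputs.shape) < prob
    labels[~selected] = ignore_index  # loss is computed only on selected positions

    # 80% of selected positions -> mask token.
    to_mask = selected & (rng.random(inputs.shape) < 0.8)
    inputs[to_mask] = mask_token_id

    # Half of the remaining 20% (10% overall) -> random token.
    to_random = selected & ~to_mask & (rng.random(inputs.shape) < 0.5)
    random_ids = rng.integers(0, vocab_size, size=inputs.shape)
    inputs[to_random] = random_ids[to_random]

    # The last 10% of selected positions keep their original token.
    return inputs, labels
```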

class easydel.trainers.utils.DataCollatorForPreferenceGrain(max_prompt_length: int, max_completion_length: int, pad_token_id: int = 0, label_pad_token_id: int = -100, is_encoder_decoder: bool | None = False)

Bases: object

Data collator for Direct Preference Optimization (DPO) with Grain.

Grain-compatible version of DataCollatorForPreferenceTFDS. Processes single dictionaries instead of lists for Grain’s data pipeline.

max_prompt_length

Maximum length for prompt sequences.

Type

int

max_completion_length

Maximum length for completion sequences.

Type

int

pad_token_id

Token ID to use for padding (default 0).

Type

int

label_pad_token_id

Token ID for label padding (default -100).

Type

int

is_encoder_decoder

Whether using encoder-decoder architecture.

Type

bool | None

Note

Returns NumPy arrays for Grain compatibility. Handles a single feature dictionary rather than a list.

classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

is_encoder_decoder: bool | None = False
label_pad_token_id: int = -100
max_completion_length: int
max_prompt_length: int
pad_token_id: int = 0
replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

class easydel.trainers.utils.DataCollatorForPreferenceTFDS(max_prompt_length: int, max_completion_length: int, pad_token_id: int = 0, label_pad_token_id: int = -100, is_encoder_decoder: bool | None = False)

Bases: object

Data collator for Direct Preference Optimization (DPO) with TFDS.

Handles batching and padding of prompt-completion pairs for preference learning. Each example has a prompt, chosen completion, and rejected completion that need to be padded separately.

max_prompt_length

Maximum length for prompt sequences.

Type

int

max_completion_length

Maximum length for completion sequences.

Type

int

pad_token_id

Token ID to use for padding (default 0).

Type

int

label_pad_token_id

Token ID for label padding (default -100).

Type

int

is_encoder_decoder

Whether using encoder-decoder architecture.

Type

bool | None

Note

Supports multimodal inputs with pixel_values and pixel_attention_mask. Can include reference model log probabilities if provided.

classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

is_encoder_decoder: bool | None = False
label_pad_token_id: int = -100
max_completion_length: int
max_prompt_length: int
pad_token_id: int = 0
replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

class easydel.trainers.utils.GRPODataCollatorGrain(max_prompt_length: int, pad_token_id: int = 0)

Bases: object

Grain-compatible GRPO data collator.

classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

max_prompt_length: int
pad_token_id: int = 0
replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

class easydel.trainers.utils.GRPODataCollatorTFDS(max_prompt_length: int, pad_token_id: int = 0)

Bases: object

Data collator for GRPO training with TFDS backends.

GRPO only needs prompts since completions are generated online.

classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

max_prompt_length: int
pad_token_id: int = 0
replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

class easydel.trainers.utils.HFDataSource(dataset: IterableDataset, shard_options: ShardOptions, num_threads: int = 1)

Bases: RandomAccessDataSource

Grain-compatible data source for HuggingFace IterableDatasets.

Bridges HuggingFace’s IterableDataset with Google’s Grain data loading library, enabling efficient distributed data loading with proper sharding.

This class handles:
  • Multi-threaded data loading

  • Dataset sharding across distributed workers

  • Thread-safe iteration over dataset shards

dataset

The HuggingFace IterableDataset to wrap.

shard_options

Grain sharding configuration.

num_threads

Number of worker threads for data loading.

Note

Automatically handles dataset sharding based on world size and rank. Issues warnings if dataset shards don’t match expected shard count.

class easydel.trainers.utils.JaxDistributedConfig

Bases: object

Configuration manager for JAX distributed training.

This class handles the initialization of JAX distributed computing environments, enabling multi-host and multi-device training setups. Originally from the EasyLM project.

The class manages:
  • Multi-process coordination

  • Device assignment

  • Communication setup between processes

Note

This is typically used internally by TrainingArguments and should not need to be configured directly by users in most cases.

static get_default_config(updates=None)

Get default configuration for JAX distributed.

Parameters

updates – Optional dictionary of configuration updates to apply to the default configuration.

Returns

Configuration dictionary with the following fields:
  • initialize_jax_distributed: Whether to initialize distributed

  • coordinator_address: Address of the coordinator process

  • num_processes: Total number of processes

  • process_id: ID of the current process

  • local_device_ids: Comma-separated list of local device IDs

Return type

ConfigDict

Note

Uses ml_collections placeholders for required fields that must be provided at runtime.

classmethod initialize(config=None)

Initialize JAX distributed with the given configuration.

Parameters

config – Configuration dictionary or None to use defaults. If provided, should contain distributed setup parameters.

Note

Only initializes if config.initialize_jax_distributed is True. Parses local_device_ids from a comma-separated string if provided.

Raises

RuntimeError – If JAX distributed initialization fails.

class easydel.trainers.utils.RewardDataCollatorWithPaddingGrain(tokenizer: Any, padding: bool | str = 'max_length', max_length: int | None = None, truncation_mode: str = 'keep_end')

Bases: object

Data collator for reward modeling with Grain data loading.

Similar to RewardDataCollatorWithPaddingTFDS but designed for use with Google’s Grain data loading library. Handles single dictionaries instead of lists of dictionaries.

tokenizer

The tokenizer/processor for encoding text.

Type

Any

padding

Padding strategy - ‘max_length’, True, or False.

Type

bool | str

max_length

Maximum sequence length for padding.

Type

int | None

truncation_mode

How to truncate sequences (‘keep_end’ or ‘keep_start’).

Type

str

Note

Returns NumPy arrays instead of JAX arrays for Grain compatibility. Expects a single feature dictionary rather than a list.

classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

max_length: int | None = None
padding: bool | str = 'max_length'
replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

tokenizer: Any
truncation_mode: str = 'keep_end'
class easydel.trainers.utils.RewardDataCollatorWithPaddingTFDS(tokenizer: Any, padding: bool | str = 'max_length', max_length: int | None = None, truncation_mode: str = 'keep_end')

Bases: object

Reward DataCollator class that pads the inputs to the maximum length of the batch.

Parameters
  • tokenizer (ProcessingClassType) – The tokenizer used for encoding the data.

  • padding (bool | str, optional, defaults to 'max_length') – Padding strategy to pass to the tokenizer.

  • max_length (int or None, optional, defaults to None) – If set, sequences are padded to this maximum length.

  • truncation_mode (str, optional, defaults to 'keep_end') – How to truncate sequences (‘keep_end’ or ‘keep_start’).

classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

max_length: int | None = None
padding: bool | str = 'max_length'
replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

tokenizer: Any
truncation_mode: str = 'keep_end'
class easydel.trainers.utils.ToNumpy

Bases: MapTransform

Grain transform to convert data elements to NumPy arrays.

Ensures all values in a dictionary are converted to NumPy arrays, which is often required for JAX-based training pipelines.

map(element)

Convert all values in element to NumPy arrays.

Parameters

element – Dictionary with values to convert.

Returns

Same dictionary with all values as NumPy arrays.

Return type

dict

easydel.trainers.utils.add_bos_token_if_needed(bos_token_id: int | None, prompt_len_input_ids: int, prompt_tokens: dict[str, list[int]], chosen_prompt_len_input_ids: int, chosen_tokens: dict[str, list[int]], rejected_prompt_len_input_ids: int, rejected_tokens: dict[str, list[int]])

Add beginning-of-sequence token to prompts if needed.

Ensures all prompt sequences start with the BOS token, for consistency in preference-learning scenarios.

Parameters
  • bos_token_id – BOS token ID, or None if not used.

  • prompt_len_input_ids – Length of main prompt.

  • prompt_tokens – Main prompt token dictionary.

  • chosen_prompt_len_input_ids – Length of chosen prompt.

  • chosen_tokens – Chosen response token dictionary.

  • rejected_prompt_len_input_ids – Length of rejected prompt.

  • rejected_tokens – Rejected response token dictionary.

Returns

(prompt_tokens, chosen_tokens, rejected_tokens) with BOS added.

Return type

tuple

Note

Only adds BOS if it’s not already present at the beginning. Updates both input_ids and attention_mask.
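The BOS check can be sketched for a single token dictionary as follows. prepend_bos_if_missing is a hypothetical helper, and the "input_ids"/"attention_mask" key names are assumptions of this sketch; the documented function presumably applies an equivalent check to each of the prompt dictionaries it receives.

```python
def prepend_bos_if_missing(bos_token_id, tokens):
    """Hypothetical sketch: prepend BOS to an {'input_ids', 'attention_mask'} dict."""
    if bos_token_id is None:
        return tokens  # the model has no BOS token, so nothing to add
    input_ids = tokens["input_ids"]
    if len(input_ids) == 0 or input_ids[0] != bos_token_id:
        # Build a new dict so the caller's lists are not mutated.
        return {
            "input_ids": [bos_token_id] + input_ids,
            "attention_mask": [1] + tokens["attention_mask"],
        }
    return tokens
```

Both lists grow by one element together, so input_ids and attention_mask stay aligned.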

easydel.trainers.utils.add_eos_token_if_needed(eos_token_id: int, chosen_tokens: dict[str, list[int]], rejected_tokens: dict[str, list[int]])

Add end-of-sequence token to responses if needed.

Ensures both chosen and rejected responses end with the EOS token for proper sequence termination.

Parameters
  • eos_token_id – EOS token ID to add.

  • chosen_tokens – Chosen response token dictionary.

  • rejected_tokens – Rejected response token dictionary.

Returns

(chosen_tokens, rejected_tokens) with EOS added.

Return type

tuple

Note

Only adds EOS if it’s not already present at the end. Updates both input_ids and attention_mask.
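A matching sketch for the EOS side, applied to a chosen and a rejected response. append_eos_if_missing is hypothetical and assumes the same "input_ids"/"attention_mask" keys as above; it is not the library’s code.

```python
def append_eos_if_missing(eos_token_id, tokens):
    """Hypothetical sketch: append EOS to an {'input_ids', 'attention_mask'} dict."""
    input_ids = tokens["input_ids"]
    if len(input_ids) == 0 or input_ids[-1] != eos_token_id:
        # New lists keep input_ids and attention_mask the same length.
        return {
            "input_ids": input_ids + [eos_token_id],
            "attention_mask": tokens["attention_mask"] + [1],
        }
    return tokens


# The chosen response lacks EOS; the rejected one already ends with it.
chosen = append_eos_if_missing(2, {"input_ids": [7, 8], "attention_mask": [1, 1]})
rejected = append_eos_if_missing(2, {"input_ids": [9, 2], "attention_mask": [1, 1]})
```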

easydel.trainers.utils.conversations_formatting_function(processing_class: AutoTokenizer, messages_field: Literal['messages', 'conversations'], tools: list | None = None)

Create a formatter for conversation/chat datasets.

Returns a function that applies chat templates to conversation data, converting structured conversations into formatted text suitable for training chat models.

Parameters
  • processing_class – Tokenizer with chat template support.

  • messages_field – Field name containing conversations - either ‘messages’ or ‘conversations’.

  • tools – Optional list of tools for function calling support.

Returns

Function that formats dataset examples using the tokenizer’s chat template.

Return type

Callable

Note

Handles both single conversations and batches of conversations. The returned function expects datasets with the specified messages_field containing role-based conversation data.

easydel.trainers.utils.create_constant_length_dataset(processing_class, dataset, dataset_text_field: str | None = None, formatting_func: Optional[Callable] = None, infinite: bool = False, seq_length: int = 1024, num_of_sequences: int = 1024, chars_per_token: float = 3.6, eos_token_id: int = 0, shuffle: bool = True, append_concat_token: bool = True, add_special_tokens: bool = True) → Callable[[], Iterator[dict[str, jax.Array]]]

Creates a generator function that yields constant-length chunks of tokens from a stream of text.

Parameters
  • processing_class – The processor used for processing the data.

  • dataset – Dataset containing the text samples to pack.

  • dataset_text_field – Name of the field in the dataset that contains the text.

  • formatting_func – Function that formats the text before tokenization.

  • infinite – If True, the iterator restarts when the dataset is exhausted; otherwise it stops.

  • seq_length – Length of token sequences to return.

  • num_of_sequences – Number of token sequences to keep in buffer.

  • chars_per_token – Number of characters per token, used to estimate the number of tokens in the text buffer.

  • eos_token_id – ID of the end-of-sequence token, used if the passed processing_class does not have an EOS token.

  • shuffle – Shuffle the examples before they are returned.

  • append_concat_token – If True, appends eos_token_id to the end of each sample being packed.

  • add_special_tokens – If True, processing_class adds special tokens to each sample being packed.

Returns

A generator function that yields dictionaries containing input_ids and attention_mask as JAX arrays.
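The core packing idea can be sketched as below: tokenize each text, append a separator token, and cut the running token stream into seq_length chunks. This is a simplified, hypothetical sketch rather than EasyDeL’s implementation. It uses NumPy instead of JAX arrays, takes a plain tokenize callable in place of processing_class, omits the chars_per_token/num_of_sequences buffering, shuffling, and infinite mode, and drops the trailing partial chunk.

```python
from typing import Callable, Iterable, Iterator

import numpy as np


def pack_constant_length(
    texts: Iterable[str],
    tokenize: Callable[[str], list[int]],
    seq_length: int = 8,
    eos_token_id: int = 0,
    append_concat_token: bool = True,
) -> Iterator[dict[str, np.ndarray]]:
    """Hypothetical sketch: yield fixed-length token chunks from a text stream."""
    buffer: list[int] = []
    for text in texts:
        ids = list(tokenize(text))
        if append_concat_token:
            ids.append(eos_token_id)  # separator between packed samples
        buffer.extend(ids)
        # Emit every full chunk currently available.
        while len(buffer) >= seq_length:
            chunk, buffer = buffer[:seq_length], buffer[seq_length:]
            yield {
                "input_ids": np.asarray(chunk, dtype=np.int32),
                "attention_mask": np.ones(seq_length, dtype=np.int32),
            }
    # Any trailing partial chunk is dropped so every example has seq_length tokens.
```

Because samples are concatenated, a chunk may span the end of one sample and the start of the next; the separator token marks the boundary.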

easydel.trainers.utils.create_prompt_creator(processing_class)

Create a prompt formatting function for conversation data.

Parameters

processing_class – Tokenizer or processor class used for formatting.

Returns

A function that formats conversation samples into prompts suitable for training.

Return type

Callable

Note

The returned function expects samples with a ‘conversation’ field containing input/output pairs and formats them using the conversations_formatting_function.

easydel.trainers.utils.first_true_indices(bools, dtype=<class 'jax.numpy.int32'>)

Find the index of the first True value along the last axis.

Takes an N-dimensional bool array and returns an (N-1)-dimensional array of integers giving the position of the first True in each “row”.

Parameters
  • bools – N-dimensional boolean array to search.

  • dtype – Data type for the output indices (default jnp.int32).

Returns

(N-1)-dimensional array of indices where each element is the position of the first True in the corresponding row, or row_len if the row contains no True value.

Return type

jnp.ndarray

Note

Locates the first True with a vectorized minimum over the last axis rather than a Python loop.
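One way to implement this vectorized minimum is sketched in NumPy below; first_true_indices_np is a hypothetical name, and the documented function works on jnp arrays instead. False positions are pushed to row_len plus their index, True positions keep their index, so the row minimum is the first True index, or row_len when the row has no True.

```python
import numpy as np


def first_true_indices_np(bools, dtype=np.int32):
    """Hypothetical NumPy sketch of a first-True search along the last axis."""
    bools = np.asarray(bools, dtype=bool)
    row_len = bools.shape[-1]
    # True -> its index (< row_len); False -> row_len + index (>= row_len).
    zero_or_index = row_len * (~bools).astype(dtype) + np.arange(row_len, dtype=dtype)
    return zero_or_index.min(axis=-1)


print(first_true_indices_np([[False, True, True], [False, False, False]]))
```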

easydel.trainers.utils.get_formatting_func_from_dataset(dataset: Union[Dataset, ConstantLengthDataset], processing_class: AutoTokenizer, tools: list | None = None) → Optional[Callable]

Automatically detect and return appropriate formatting function.

Examines dataset structure to determine the appropriate formatting function (chat format or instruction format) based on field names and schemas.

Parameters
  • dataset – HuggingFace Dataset to analyze.

  • processing_class – Tokenizer to use for formatting.

  • tools – Optional tools for function calling support.

Returns

Appropriate formatting function, or None if no suitable format is detected.

Return type

Callable | None

Note

Supports:
  • ChatML format (messages/conversations fields)

  • Instruction format (prompt/completion fields)

Returns None if the dataset doesn’t match a known format.

easydel.trainers.utils.instructions_formatting_function(processing_class: AutoTokenizer)

Create a formatter for instruction-following datasets.

Returns a function that converts prompt-completion pairs into chat format using the tokenizer’s chat template. Originally from TRL.

Parameters

processing_class – Tokenizer with chat template support.

Returns

Function that formats instruction datasets by converting prompt/completion pairs to user/assistant conversations.

Return type

Callable

Note

Expects datasets with ‘prompt’ and ‘completion’ fields. Automatically converts to chat format with user/assistant roles.
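The prompt/completion-to-chat conversion can be sketched without a tokenizer as follows. to_chat_messages is a hypothetical helper, not part of EasyDeL; it handles both a single example and a batched example whose fields are lists.

```python
def to_chat_messages(example):
    """Hypothetical sketch: map prompt/completion fields to user/assistant turns."""
    prompt, completion = example["prompt"], example["completion"]
    if isinstance(prompt, list):  # batched example: one conversation per pair
        return [
            to_chat_messages({"prompt": p, "completion": c})
            for p, c in zip(prompt, completion)
        ]
    return [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": completion},
    ]
```

With a Hugging Face tokenizer, each resulting message list could then be rendered to text with tokenizer.apply_chat_template(messages, tokenize=False).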

easydel.trainers.utils.leave_alone_context_manager()

No-op context manager.

Useful as a placeholder when a context manager is required but no actual context management is needed.

Yields

None – Simply yields control back to the caller.
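A minimal equivalent built with contextlib is sketched below, used to choose between a real context manager and the no-op placeholder. leave_alone and maybe_locked are hypothetical names; the standard library’s contextlib.nullcontext() offers the same behavior.

```python
import contextlib
import threading


@contextlib.contextmanager
def leave_alone():
    """No-op context manager: yields None and does nothing on exit."""
    yield


def maybe_locked(lock=None):
    # Use the real lock when one is supplied, otherwise the no-op placeholder.
    return lock if lock is not None else leave_alone()


with maybe_locked() as value:
    pass  # value is None; no locking happened

with maybe_locked(threading.Lock()):
    pass  # the lock is held inside this block
```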

easydel.trainers.utils.np_pad(tensors: list[numpy.ndarray], max_lenght: int | None, padding_value: int = 0, padding_side: str = 'right') → ndarray

Pad a list of NumPy tensors to uniform shape.

Similar to pad() but for NumPy arrays instead of JAX arrays.

Parameters
  • tensors – List of NumPy arrays to pad.

  • max_lenght – Target length for padding. If None, uses the maximum length found in tensors.

  • padding_value – Value to use for padding (default 0).

  • padding_side – Where to add padding - ‘left’ or ‘right’ (default ‘right’).

Returns

Batched and padded array.

Return type

np.ndarray

Raises

ValueError – If padding_side is not ‘left’ or ‘right’.
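For the common case of 1-D token sequences, left/right padding can be sketched as follows. pad_1d_batch is a hypothetical helper, not np_pad itself; it only handles 1-D inputs and, as an assumption of this sketch, raises if a sequence is longer than the target length.

```python
import numpy as np


def pad_1d_batch(tensors, max_length=None, padding_value=0, padding_side="right"):
    """Hypothetical sketch: stack 1-D arrays into a [batch, length] padded array."""
    if padding_side not in ("left", "right"):
        raise ValueError("padding_side must be 'left' or 'right'")
    arrays = [np.asarray(t) for t in tensors]
    # Target length: explicit max_length, else the longest sequence in the batch.
    target = max_length if max_length is not None else max(len(a) for a in arrays)
    out = np.full((len(arrays), target), padding_value, dtype=arrays[0].dtype)
    for row, a in enumerate(arrays):
        if len(a) > target:
            raise ValueError("sequence is longer than max_length")
        if padding_side == "right":
            out[row, : len(a)] = a  # values first, padding after
        else:
            out[row, target - len(a):] = a  # padding first, values after
    return out
```

Left padding is the usual choice for batched generation with decoder-only models, since it keeps the last prompt token at the end of every row.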

easydel.trainers.utils.pad(tensors: list[jax.Array], max_lenght: int | None, padding_value: int = 0, padding_side: str = 'right') → Array

Pad a list of JAX tensors to uniform shape.

Parameters
  • tensors – List of JAX arrays to pad.

  • max_lenght – Target length for padding. If None, uses maximum length found in tensors.

  • padding_value – Value to use for padding (default 0).

  • padding_side – Where to add padding - ‘left’ or ‘right’ (default ‘right’).

Returns

Batched and padded tensor with shape [batch_size, *tensor_shape, max_length].

Return type

jnp.ndarray

Raises

ValueError – If padding_side is not ‘left’ or ‘right’.

Note

Handles variable-length sequences by padding them to a common length. Preserves the dtype of the input tensors.

easydel.trainers.utils.pad_sequence(sequences, batch_first=False, padding_value=0, max_len: int | None = None)

Pad a list of sequences to the same length.

Parameters
  • sequences – List of sequences (arrays) to pad.

  • batch_first – If True, output has batch dimension first. If False, adds padding to the left (default False).

  • padding_value – Value to use for padding (default 0).

  • max_len – Maximum length to pad to. If None, uses longest sequence.

Returns

Padded sequences as a single batched array.

Return type

jnp.ndarray

Note

Similar to PyTorch’s pad_sequence but for JAX arrays. When batch_first=False, padding is added to the left.

easydel.trainers.utils.pad_single(tensor: ndarray, max_length: int | None = None, padding_value: int = 0, padding_side: str = 'right') → ndarray

Pad a single NumPy tensor along the last dimension.

Parameters
  • tensor – NumPy array to pad.

  • max_length – Target length for the last dimension. If None, returns tensor unchanged.

  • padding_value – Value to use for padding (default 0).

  • padding_side – Where to add padding - ‘left’ or ‘right’ (default ‘right’).

Returns

Padded tensor with last dimension of size max_length.

Return type

np.ndarray

Raises

ValueError – If padding_side is not ‘left’ or ‘right’.

Note

If tensor is already longer than max_length, it will be truncated from the appropriate side based on padding_side.

easydel.trainers.utils.pad_to_length(tensor: Union[Array, ndarray, bool, number], length: int, pad_value: int | float, axis: int = -1) → Union[Array, ndarray, bool, number]

Pad or truncate a tensor to a specific length along an axis.

Parameters
  • tensor – Input array to pad or truncate.

  • length – Target length for the specified axis.

  • pad_value – Value to use for padding.

  • axis – Axis along which to pad/truncate (default -1).

Returns

Tensor padded or truncated to the specified length.

Return type

chex.Array

Note

If tensor is already longer than length, it will be truncated. Special handling for 2D tensors.
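A NumPy sketch of padding or truncating along one axis is shown below. pad_or_truncate is hypothetical; it keeps the leading elements when truncating (an assumption), and it does not reproduce the special handling of 2D tensors mentioned in the note.

```python
import numpy as np


def pad_or_truncate(tensor, length, pad_value, axis=-1):
    """Hypothetical sketch: force `tensor.shape[axis] == length`."""
    tensor = np.asarray(tensor)
    axis = axis % tensor.ndim  # normalize negative axes
    current = tensor.shape[axis]
    if current >= length:
        # Assumption: truncation keeps the first `length` elements.
        return np.take(tensor, np.arange(length), axis=axis)
    pad_width = [(0, 0)] * tensor.ndim
    pad_width[axis] = (0, length - current)  # pad at the end of `axis` only
    return np.pad(tensor, pad_width, mode="constant", constant_values=pad_value)
```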

easydel.trainers.utils.shift_and_pad(mask, *tensors)

Shift tensors to align with the first non-zero mask position.

Rolls each tensor so that the first ‘1’ in the mask appears at position 0. Useful for aligning sequences that have different starting positions.

Parameters
  • mask – Binary mask array indicating valid positions.

  • *tensors – Additional tensors to shift along with the mask.

Returns

Shifted mask and tensors (if provided), or just mask if no tensors.

Return type

tuple

Note

Modifies inputs in-place. Each row is shifted independently based on its first non-zero mask position.
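The per-row roll can be sketched in NumPy as follows. shift_left_to_first_valid is a hypothetical name; unlike the documented function, this sketch copies its inputs instead of modifying them in place, and rows with no nonzero mask entry are left unshifted (an assumption).

```python
import numpy as np


def shift_left_to_first_valid(mask, *tensors):
    """Hypothetical sketch: roll each row so its first nonzero mask entry lands at 0."""
    mask = np.array(mask, copy=True)
    tensors = [np.array(t, copy=True) for t in tensors]
    for i in range(mask.shape[0]):
        nonzero = np.nonzero(mask[i])[0]
        shift = int(nonzero[0]) if nonzero.size else 0
        mask[i] = np.roll(mask[i], -shift)
        for t in tensors:
            t[i] = np.roll(t[i], -shift)  # same shift keeps tensors aligned with mask
    return (mask, *tensors) if tensors else mask
```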

easydel.trainers.utils.tolist(x)

Convert various array types to Python list.

Utility function from HuggingFace for consistent list conversion.

Parameters

x – Input to convert. Can be:
  • Python list (returned as-is)

  • NumPy array

  • JAX array

  • Tensor with a .numpy() method

Returns

Python list representation of the input.

Return type

list

Note

Handles tensors by first converting to NumPy if they have a .numpy() method.

easydel.trainers.utils.truncate_right(input_ids, stop_token_id, pad_token_id)

Truncate sequences after the first occurrence of stop token.

Replaces all tokens after the first stop token with padding tokens and creates a corresponding attention mask.

Parameters
  • input_ids – 2D array of token IDs [batch_size, sequence_length].

  • stop_token_id – Token ID that marks where to stop.

  • pad_token_id – Token ID to use for padding truncated positions.

Returns

(output_ids, mask) where:
  • output_ids: Input with post-stop tokens replaced by padding

  • mask: Binary attention mask (1 for valid, 0 for padded)

Return type

tuple

Note

Useful for truncating generated sequences at EOS tokens. Preserves the stop token itself in the output.
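The truncation can be sketched in NumPy as follows. truncate_after_stop is a hypothetical name, not the library function; it returns NumPy arrays and keeps the stop token itself, matching the note above.

```python
import numpy as np


def truncate_after_stop(input_ids, stop_token_id, pad_token_id):
    """Hypothetical sketch: pad out everything after the first stop token per row."""
    input_ids = np.asarray(input_ids)
    seq_len = input_ids.shape[1]
    is_stop = input_ids == stop_token_id
    # argmax returns the first True; rows without a stop token get seq_len.
    first_stop = np.where(is_stop.any(axis=1), is_stop.argmax(axis=1), seq_len)
    positions = np.arange(seq_len)[None, :]
    mask = (positions <= first_stop[:, None]).astype(np.int32)  # stop token stays valid
    output_ids = np.where(mask == 1, input_ids, pad_token_id)
    return output_ids, mask
```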