easydel.__init__

easydel.__init__#

class easydel.__init__.ArcticConfig(vocab_size=32000, hidden_size=4096, intermediate_size=14336, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=None, hidden_act='silu', max_position_embeddings=4096, initializer_range=0.02, rms_norm_eps=1e-05, use_cache=True, pad_token_id=None, bos_token_id=1, eos_token_id=2, tie_word_embeddings=False, rope_theta=1000000.0, sliding_window=None, attention_dropout=0.0, num_experts_per_tok=1, num_local_experts=8, router_aux_loss_coef=0.001, moe_layer_frequency=2, parallel_attn_mlp_res=False, moe_train_capacity_factor=1, moe_eval_capacity_factor=1, enable_expert_tensor_parallelism=False, moe_min_capacity=0, moe_token_dropping=True, quantization=None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, rope_scaling: Dict[str, Union[str, float]] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 32000) – Vocabulary size of the Arctic model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4096) – Dimensionality of the hidden states in the decoder layers.

  • intermediate_size (int, optional, defaults to 14336) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer decoder.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer decoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer decoder.

  • num_key_value_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer decoder. Defaults to num_attention_heads if not set.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the decoder. If a string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 4096) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional) – The index of the padding token in the vocabulary. Defaults to None.

  • bos_token_id (int, optional, defaults to 1) – The index of the beginning-of-sequence token in the vocabulary.

  • eos_token_id (int, optional, defaults to 2) – The index of the end-of-sequence token in the vocabulary.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 1e6) – The theta value to use for rotary position embeddings.

  • sliding_window (int, optional) – The sliding window size to use for attention. If not specified, no sliding window attention is used.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • num_experts_per_tok (int, optional, defaults to 1) – The number of experts per token for mixture of experts.

  • num_local_experts (int, optional, defaults to 8) – The number of local experts for mixture of experts.

  • router_aux_loss_coef (float, optional, defaults to 0.001) – The auxiliary loss coefficient for the router.

  • moe_layer_frequency (int, optional, defaults to 2) – The frequency of MoE layers.

  • parallel_attn_mlp_res (bool, optional, defaults to False) – Whether to parallelize attention and MLP residual connections.

  • moe_train_capacity_factor (float, optional, defaults to 1) – The capacity factor for MoE layers during training.

  • moe_eval_capacity_factor (float, optional, defaults to 1) – The capacity factor for MoE layers during evaluation.

  • enable_expert_tensor_parallelism (bool, optional, defaults to False) – Whether to enable expert tensor parallelism.

  • moe_min_capacity (int, optional, defaults to 0) – The minimum capacity for MoE layers.

  • moe_token_dropping (bool, optional, defaults to True) – Whether to drop tokens in MoE layers.

  • quantization (str, optional) – The quantization configuration.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing policy.

  • use_scan_mlp (bool, optional, defaults to False) – Whether to use scan for MLP.

  • scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size for scan MLP.

  • bits (int, optional) – The number of bits.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The rope scaling configuration.
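
The fallback documented for num_key_value_heads can be sketched as follows. SketchAttentionConfig is a hypothetical stand-in and not the ArcticConfig implementation; the divisibility check is an added assumption. It only illustrates how an unset value resolves to num_attention_heads and how many query heads then share one key/value head.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchAttentionConfig:
    """Hypothetical stand-in for the head-count fields of ArcticConfig."""

    num_attention_heads: int = 32
    num_key_value_heads: Optional[int] = None

    def __post_init__(self) -> None:
        # Documented fallback: an unset KV head count defaults to the query head count.
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads
        # Assumption for this sketch: query heads must split evenly over KV heads.
        if self.num_attention_heads % self.num_key_value_heads != 0:
            raise ValueError("num_attention_heads must be divisible by num_key_value_heads")

    @property
    def queries_per_kv_head(self) -> int:
        # 1 means plain multi-head attention; larger values mean grouped-query attention.
        return self.num_attention_heads // self.num_key_value_heads


mha = SketchAttentionConfig()                       # KV heads fall back to 32
gqa = SketchAttentionConfig(num_key_value_heads=8)  # 4 query heads per KV head
```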

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()[source]#

Returns a tuple of parameter names for which weight decay should be excluded.

property granted_freq_max_position_embedding: int#

Returns the maximum position embedding size for frequency-based position embeddings, falling back to max_position_embeddings if freq_max_position_embeddings is not set.

property granted_mask_max_position_embedding: int#

Returns the maximum position embedding size for mask-based position embeddings, falling back to max_position_embeddings if mask_max_position_embeddings is not set.
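
A minimal sketch of the fallback behaviour described for the two granted_* properties. SketchPositionConfig is hypothetical, and treating None as "not set" is an assumption; the real config may track unset values differently.

```python
from typing import Optional


class SketchPositionConfig:
    """Hypothetical config holding only the fields the two properties read."""

    def __init__(
        self,
        max_position_embeddings: int = 4096,
        freq_max_position_embeddings: Optional[int] = None,
        mask_max_position_embeddings: Optional[int] = None,
    ) -> None:
        self.max_position_embeddings = max_position_embeddings
        self.freq_max_position_embeddings = freq_max_position_embeddings
        self.mask_max_position_embeddings = mask_max_position_embeddings

    @property
    def granted_freq_max_position_embedding(self) -> int:
        # Prefer the frequency-specific limit; otherwise fall back to the general one.
        if self.freq_max_position_embeddings is None:
            return self.max_position_embeddings
        return self.freq_max_position_embeddings

    @property
    def granted_mask_max_position_embedding(self) -> int:
        # Same fallback rule, applied to the mask-specific limit.
        if self.mask_max_position_embeddings is None:
            return self.max_position_embeddings
        return self.mask_max_position_embeddings


default_cfg = SketchPositionConfig()
custom_cfg = SketchPositionConfig(freq_max_position_embeddings=8192)
```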

model_type: str = 'arctic'#
static rng_keys()[source]#

Returns the names of the random number generator keys used by the model.

class easydel.__init__.ArcticForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Arctic model specifically adapted for Causal Language Modeling (CLM). This module wraps the core ArcticModel and adds a language modeling head on top.

config#

Configuration object for the Arctic model.

Type

ArcticConfig

dtype#

Data type for computation. Defaults to jnp.float32.

Type

jnp.dtype

param_dtype#

Data type for parameters. Defaults to jnp.float32.

Type

jnp.dtype

precision#

Precision setting for JAX operations. Defaults to None.

Type

jax.lax.PrecisionLike

rngs#

Random number generators for the module.

Type

nn.Rngs

class easydel.__init__.ArcticModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Core Arctic model architecture. This module implements the main Transformer stack for the Arctic model, including token embeddings and decoder layers.

config#

Configuration object for the Arctic model.

Type

ArcticConfig

dtype#

Data type for computation. Defaults to jnp.float32.

Type

jnp.dtype

param_dtype#

Data type for parameters. Defaults to jnp.float32.

Type

jnp.dtype

precision#

Precision setting for JAX operations. Defaults to None.

Type

jax.lax.PrecisionLike

rngs#

Random number generators for the module.

Type

nn.Rngs

class easydel.__init__.AttentionMechanisms(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: str, Enum

Enumeration of available attention mechanisms.

AUTO#

Automatically selects the best mechanism based on the backend.

FLASH_ATTN2#

FlashAttention-2 implementation.

RING#

RingAttention implementation.

VANILLA#

Standard dot-product attention.

SPLASH#

SplashAttention implementation (optimized for TPUs).

CUDNN#

cuDNN implementation (GPU specific).

BLOCKWISE#

Blockwise attention computation.

SDPA#

Scaled Dot Product Attention (potentially uses JAX native SDPA).

CUDA_FLASH_ATTN2#

CUDA specific FlashAttention-2 implementation.

PAGED_ATTENTION#

Paged attention for fast inference.

AUTO = 'auto'#
BLOCKWISE = 'blockwise'#
CUDA_FLASH_ATTN2 = 'cuda_flash_attn2'#
CUDNN = 'cudnn'#
FLASH_ATTN2 = 'flash_attn2'#
PAGED_ATTENTION = 'paged_attention'#
RING = 'ring'#
SDPA = 'sdpa'#
SPLASH = 'splash'#
VANILLA = 'vanilla'#
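
Because the enumeration derives from both str and Enum, members can be looked up from their string value and compared directly with plain strings. The sketch below uses a local mirror of three members (SketchAttentionMechanisms is a hypothetical name) so that it runs without importing easydel.

```python
from enum import Enum


class SketchAttentionMechanisms(str, Enum):
    """Local mirror of a few members listed above; not imported from easydel."""

    AUTO = "auto"
    FLASH_ATTN2 = "flash_attn2"
    VANILLA = "vanilla"


# Lookup by value, as when a mechanism name arrives from a config file or CLI flag.
selected = SketchAttentionMechanisms("flash_attn2")

# The str mixin makes members compare equal to their raw string value.
is_flash = selected == "flash_attn2"

# Unknown names raise ValueError, which is a convenient place to validate user input.
try:
    SketchAttentionMechanisms("not_a_mechanism")
    lookup_failed = False
except ValueError:
    lookup_failed = True
```
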
class easydel.__init__.AttentionMetadata(runtime_dtype: Union[str, type[Any], dtype, SupportsDType], runtime_softmax_dtype: Optional[Union[str, type[Any], dtype, SupportsDType]] = None, sequence_axis_name: str = Ellipsis, mesh: Optional[Mesh] = Ellipsis, platform: EasyDeLPlatforms = Ellipsis, backend: EasyDeLBackends = Ellipsis, partition_axis: PartitionAxis = Ellipsis, base_config: Optional[EasyDeLBaseConfig] = None, scan_ring_attention: bool = Ellipsis, softmax_scale: float = Ellipsis, dropout_prob: float = Ellipsis, blocksize_q: int = Ellipsis, blocksize_k: int = Ellipsis, blocksize_b: int = Ellipsis)[source]#

Bases: object

Holds configuration, context, and metadata for attention operations.

This class centralizes various parameters needed by different attention implementations, facilitating consistent behavior and configuration. It handles default values and can be initialized from an EasyDeLBaseConfig.

runtime_dtype#

The primary JAX dtype for computations (e.g., q, k, v).

Type

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType]

runtime_softmax_dtype#

Optional JAX dtype for the softmax computation, allowing for higher precision if needed (e.g., float32).

Type

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType]]

sequence_axis_name#

The name used for the sequence axis in JAX parallelism (axis_names for pjit).

Type

str

mesh#

The JAX device mesh for distributed computation. Must be provided or inferred from context.

Type

Optional[jax._src.mesh.Mesh]

platform#

The target hardware platform (e.g., TPU, GPU).

Type

easydel.infra.etils.EasyDeLPlatforms

backend#

The specific JAX backend being used (e.g., TPU, CUDA, ROCM).

Type

easydel.infra.etils.EasyDeLBackends

partition_axis#

Configuration for partitioning axes in distributed settings (a PartitionAxis from eformer.escale).

Type

eformer.escale.partition.constraints.PartitionAxis

base_config#

An optional reference to the base model configuration object for sourcing default values.

Type

Optional[easydel.infra.base_config.EasyDeLBaseConfig]

scan_ring_attention#

Boolean flag indicating whether to use ring attention via jax.lax.scan.

Type

bool

softmax_scale#

The scaling factor applied before the softmax operation. Often 1 / sqrt(head_dim).

Type

float

dropout_prob#

The dropout probability applied to attention weights.

Type

float

blocksize_q#

Block size for the query sequence dimension in blockwise attention.

Type

int

blocksize_k#

Block size for the key/value sequence dimension in blockwise attention.

Type

int

blocksize_b#

Block size for the batch dimension in blockwise attention (often 1).

Type

int

backend: EasyDeLBackends = Ellipsis#
base_config: Optional[EasyDeLBaseConfig] = None#
blocksize_b: int = Ellipsis#
blocksize_k: int = Ellipsis#
blocksize_q: int = Ellipsis#
dropout_prob: float = Ellipsis#
classmethod from_config(config: EasyDeLBaseConfig, softmax_scale: float, dropout_prob: float = 0.0) AttentionMetadata[source]#

Factory method to create AttentionMetadata from an EasyDeLBaseConfig.

Parameters
  • config – The base configuration object (e.g., model config).

  • softmax_scale – The attention softmax scaling factor. Usually calculated based on head dimension.

  • dropout_prob – The attention dropout probability. Defaults to 0.0.

Returns

An initialized AttentionMetadata instance.

classmethod from_dict(data)#

Create an instance from a dictionary (deserialization).

classmethod from_json(json_str)#

Create an instance from a JSON string.

get_partition_specs(mode: RuntimeType, BTHD: bool = True)[source]#

Generates JAX PartitionSpecs for attention tensors based on runtime mode.

Parameters
  • mode – The current runtime mode (normal or generation).

  • BTHD – Boolean indicating tensor layout. True for (Batch, Time, Head, Dim), False for (Batch, Head, Time, Dim).

Returns

A tuple containing PartitionSpecs for (query, key, value, bias, mask, attention_output).

mesh: Optional[Mesh] = Ellipsis#
partition_axis: PartitionAxis = Ellipsis#
platform: EasyDeLPlatforms = Ellipsis#
replace(**kwargs)#
runtime_dtype: Union[str, type[Any], dtype, SupportsDType]#
runtime_softmax_dtype: Optional[Union[str, type[Any], dtype, SupportsDType]] = None#
scan_ring_attention: bool = Ellipsis#
sequence_axis_name: str = Ellipsis#
set_attrs_carefully(attr_name: str, default: Optional[Any], pickup_name: Optional[str] = None, use_base_config: bool = True)[source]#

Internal helper to set an attribute if it’s not already set (or is Ellipsis).

Optionally retrieves the value from self.base_config using pickup_name (or attr_name if pickup_name is None).

Parameters
  • attr_name – The name of the attribute to set on self.

  • default – The default value to use if not found in base_config or if use_base_config is False.

  • pickup_name – The name of the attribute to look for in base_config. Defaults to attr_name.

  • use_base_config – Whether to attempt retrieving the value from base_config.
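
The resolution order described above (keep an explicit value, otherwise read from base_config, otherwise use the default) can be sketched as follows. SketchMetadata is a hypothetical class, and the attention_dropout field used as pickup_name is chosen only for illustration; this is not the library's implementation.

```python
from types import SimpleNamespace
from typing import Any, Optional


class SketchMetadata:
    """Hypothetical holder that uses Ellipsis as the 'not set' sentinel."""

    def __init__(self, base_config: Optional[Any] = None, dropout_prob: Any = ...) -> None:
        self.base_config = base_config
        self.dropout_prob = dropout_prob

    def set_attrs_carefully(
        self,
        attr_name: str,
        default: Optional[Any],
        pickup_name: Optional[str] = None,
        use_base_config: bool = True,
    ) -> None:
        # Keep any value the caller supplied explicitly.
        if getattr(self, attr_name, ...) is not ...:
            return
        pickup_name = attr_name if pickup_name is None else pickup_name
        value = default
        # Prefer the base config when allowed and available; otherwise use the default.
        if use_base_config and self.base_config is not None:
            value = getattr(self.base_config, pickup_name, default)
        setattr(self, attr_name, value)


config = SimpleNamespace(attention_dropout=0.1)  # hypothetical base config

from_config = SketchMetadata(base_config=config)
from_config.set_attrs_carefully("dropout_prob", 0.0, pickup_name="attention_dropout")

explicit = SketchMetadata(base_config=config, dropout_prob=0.5)
explicit.set_attrs_carefully("dropout_prob", 0.0, pickup_name="attention_dropout")

no_config = SketchMetadata()
no_config.set_attrs_carefully("dropout_prob", 0.0, pickup_name="attention_dropout")
```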

softmax_scale: float = Ellipsis#
to_dict()#

Convert the instance to a dictionary for JSON serialization.

to_json(**kwargs)#

Convert the instance to a JSON string.

class easydel.__init__.AttentionModule(*args: Any, **kwargs: Any)[source]#

Bases: Module

Base class for Flax attention modules in EasyDeL, providing common utilities.

This class offers helper functions and attributes commonly needed by attention implementations within Flax, such as handling KV caching, sharding, mask manipulation, and head manipulation. Concrete attention implementations often inherit from this class.

config#

Configuration object for the attention module.

Type

SC | EasyDeLBaseConfig

cached_key#

Flax Cache for storing past key states (won’t be used).

Type

nn.Cache[Array] | None

cached_value#

Flax Cache for storing past value states (won’t be used).

Type

nn.Cache[Array] | None

cache_index#

Flax Cache for tracking the current index in the cache (won’t be used).

Type

nn.Cache[Array] | None

static apply_complex_rotary(xq: Array, xk: Array, freqs_cis: Array) Tuple[Array, Array][source]#
static build_cache_pos(attention_mask: Array, cache_view: TransformerCacheView = None) Array[source]#

Calculates the position indices within the sequence for cache-aware operations.

Parameters
  • attention_mask (jax.Array) – The attention mask (typically [batch, heads, q_len, k_len]).

  • cache_view (TransformerCacheView, optional) – The current KV cache view. Defaults to None.

Returns

An array representing the position of each token in the sequence, adjusted by the cache index if provided. Shape usually [batch, q_len].

Return type

jax.Array

concatenate(*, query: Union[Array, ndarray, bool, number], key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], attention_mask: Union[Array, ndarray, bool, number], cache_view: Optional[Union[TransformerCacheView, PagedAttentionCacheView]] = None, cache_metadata: Optional[Union[TransformerMetadata, PagedAttentionMetadata]] = None, causal_mask: Optional[Union[Array, ndarray, bool, number]] = None, token_type_ids: Optional[Union[Array, ndarray, bool, number]] = None, fcm_mask: Optional[Union[Array, ndarray, bool, number]] = None, sliding_windows: Optional[int] = None) Tuple[Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Callable[[], Union[Array, ndarray, bool, number]]][source]#

Prepares inputs for attention calculation, handling KV caching and mask merging.

This function combines the current query, key, and value with cached states (if applicable), merges various masks (attention, causal, FCM, sliding window), and returns the final key, value, attention mask, and a function to initialize the attention bias.

Parameters
  • query (Array) – Current query states [Batch, q_len, Heads, Dim].

  • key (Array) – Current key states [Batch, kv_len, Heads, Dim].

  • value (Array) – Current value states [Batch, kv_len, Heads, Dim].

  • attention_mask (Array) – Base attention mask (e.g., padding mask) [Batch, kv_len] or compatible.

  • cache_view (tp.Optional[TransformerCacheView], optional) – View into the KV cache. If None, caching is disabled. Defaults to None.

  • causal_mask (tp.Optional[Array], optional) – Causal mask [1, 1, q_len, kv_len]. Defaults to None.

  • token_type_ids (tp.Optional[Array], optional) – Token type IDs for segment masking [Batch, q_len]. Defaults to None.

  • fcm_mask (tp.Optional[Array], optional) – Fused-Context-Mask (specific use case) [Batch, 1, q_len, kv_len]. Defaults to None.

  • sliding_windows (tp.Optional[int], optional) – Size of the sliding attention window. If None, not applied. Defaults to None.

Returns

  • key_states (Array): Final key states (potentially from cache).

  • value_states (Array): Final value states (potentially from cache).

  • attention_mask (Array): The final combined attention mask [Batch, Heads, q_len, kv_len].

  • init_attention_bias (Callable): Function to create the attention bias tensor.

Return type

tp.Tuple[Array, Array, Array, tp.Callable[[], Array]]

property default_key_value_sharding#

Defines the default JAX sharding for key and value tensors.

Uses the partition specifications defined in the configuration’s partition_axis.

Returns

The default sharding configuration for K/V tensors.

Return type

NamedSharding

get_sharding_safely(tensor: Array) PartitionSpec[source]#

Retrieves the PartitionSpec of a tensor, falling back to the default KV sharding.

Parameters

tensor (jax.Array) – The tensor whose sharding spec is needed.

Returns

The sharding specification of the tensor.

Return type

PartitionSpec

make_flexible_sliding_window(attention_mask: Array, cache_view: TransformerCacheView, sliding_window: int)[source]#

Applies a sliding window mask to the attention mask, considering cache state.

Parameters
  • attention_mask (jax.Array) – The original attention mask.

  • cache_view (TransformerCacheView) – The current view of the KV cache.

  • sliding_window (int) – The size of the sliding window.

Returns

  • The attention mask combined with the sliding window mask.

  • A function (init_attention_bias) to create the corresponding attention bias.

Return type

tp.Tuple[jax.Array, tp.Callable[[], jax.Array]]
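
For orientation, a common definition of a causal sliding-window mask is sketched below with NumPy. This is not the cache-aware logic of make_flexible_sliding_window; the function name, the offset parameter, and the exact window boundaries are assumptions made for illustration.

```python
import numpy as np


def sliding_window_mask(q_len: int, kv_len: int, window: int, offset: int = 0) -> np.ndarray:
    """Boolean [q_len, kv_len] mask; True marks key positions a query may attend to.

    `offset` is the absolute position of the first query, e.g. the number of
    tokens already held in a KV cache (an assumption of this sketch).
    """
    q_pos = np.arange(q_len)[:, None] + offset  # absolute query positions
    k_pos = np.arange(kv_len)[None, :]          # absolute key positions
    causal = k_pos <= q_pos                     # never look at future tokens
    in_window = k_pos > q_pos - window          # keep only the last `window` keys
    return causal & in_window


# Prefill: 4 queries over 4 keys, each query sees itself and one previous token.
prefill_mask = sliding_window_mask(q_len=4, kv_len=4, window=2)

# Decoding: one new query at absolute position 5 over 6 cached keys, window of 3.
decode_mask = sliding_window_mask(q_len=1, kv_len=6, window=3, offset=5)
```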

property quantizer#

Provides an EasyQuantizer instance based on the module’s configuration.

Used for quantizing KV cache entries if enabled in the config.

Returns

The quantizer instance.

Return type

EasyQuantizer

static repeat_key_value(key, value, num_reps: int)[source]#

Repeats key and value tensors for Grouped Query Attention (GQA).

Expands the head dimension by repeating num_reps times. Uses einops for concise repetition.

Parameters
  • key (Array) – Key tensor [Batch, Seq, NumKVHeads, Dim].

  • value (Array) – Value tensor [Batch, Seq, NumKVHeads, Dim].

  • num_reps (int) – The number of times to repeat each KV head (num_attention_heads / num_kv_heads).

Returns

Repeated key and value tensors, each with shape [Batch, Seq, NumKVHeads * num_reps, Dim].

Return type

tp.Tuple[Array, Array]
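
The shape transformation can be sketched with NumPy. The method itself is documented as using einops, so np.repeat is a swapped-in equivalent here, and the ordering of the repeated heads is an assumption of this sketch.

```python
import numpy as np


def repeat_key_value(key: np.ndarray, value: np.ndarray, num_reps: int):
    """Repeat each KV head `num_reps` times along the head axis (axis 2)."""
    return np.repeat(key, num_reps, axis=2), np.repeat(value, num_reps, axis=2)


batch, seq, kv_heads, dim = 2, 5, 4, 8
key = np.random.rand(batch, seq, kv_heads, dim)
value = np.random.rand(batch, seq, kv_heads, dim)

# With 12 query heads and 4 KV heads, each KV head is repeated 12 / 4 = 3 times.
key_rep, value_rep = repeat_key_value(key, value, num_reps=3)
```

After the call, heads 0 to 2 of key_rep are copies of KV head 0, heads 3 to 5 are copies of KV head 1, and so on, which lets a standard multi-head attention kernel consume the grouped keys and values unchanged.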

shard_attention_prod(attn_output: Array) Array[source]#

Applies sharding constraints to the attention output tensor.

This is typically done before projecting the attention output back to the hidden dimension size.

Parameters

attn_output (jax.Array) – The output from the attention mechanism, usually with shape [Batch, SeqLen, NumHeads * DimPerHead].

Returns

The input tensor with applied sharding constraints based on the config.

Return type

jax.Array

class easydel.__init__.AttentionRegistry[source]#

Bases: object

Registry for discovering and managing different AttentionImpl classes.

Allows registering implementations using a decorator and retrieving or instantiating them by name.

classmethod create(impl_name: str, metadata: AttentionMetadata) AttentionImpl[source]#

Creates an instance of an attention implementation by name.

Retrieves the class associated with impl_name and initializes it with the provided metadata.

Parameters
  • impl_name – The name of the implementation to instantiate.

  • metadata – The AttentionMetadata to pass to the implementation’s constructor.

Returns

An initialized instance of the requested AttentionImpl subclass.

Raises

ValueError – If no implementation is registered with impl_name.

classmethod get(impl_name: str) Type[AttentionImpl][source]#

Retrieves an attention implementation class by its registered name.

Parameters

impl_name – The name of the implementation to retrieve.

Returns

The AttentionImpl subclass registered under the given name.

Raises

ValueError – If no implementation is registered with that name.

classmethod list_implementations() List[str][source]#

Returns a list of names of all registered attention implementations.

Returns

A list of strings, where each string is a registered implementation name.

classmethod register(impl_cls: Type[ICa]) Type[ICa][source]#

Class method decorator to register an AttentionImpl subclass.

The implementation is registered under the name(s) returned by its get_impl_name() class method.

Example:

```python
@AttentionRegistry.register
class FlashAttentionImpl(AttentionImpl):
    @classmethod
    def get_impl_name(cls) -> str:
        return "flash"

    # ... implementation ...
```

Parameters

impl_cls – The AttentionImpl subclass to register.

Returns

The registered class itself.
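
The register/get/create/list interface follows the usual decorator-registry pattern. The sketch below is self-contained and hypothetical (SketchImpl and SketchRegistry are not EasyDeL classes, and a plain dict stands in for AttentionMetadata); it only illustrates how the four class methods fit together.

```python
from typing import Dict, List, Type, Union


class SketchImpl:
    """Hypothetical base class standing in for AttentionImpl."""

    def __init__(self, metadata: dict) -> None:
        self.metadata = metadata

    @classmethod
    def get_impl_name(cls) -> Union[str, List[str]]:
        raise NotImplementedError


class SketchRegistry:
    """Hypothetical registry mirroring the interface documented above."""

    _registry: Dict[str, Type[SketchImpl]] = {}

    @classmethod
    def register(cls, impl_cls: Type[SketchImpl]) -> Type[SketchImpl]:
        names = impl_cls.get_impl_name()
        if isinstance(names, str):  # accept a single name or several
            names = [names]
        for name in names:
            cls._registry[name] = impl_cls
        return impl_cls  # returning the class lets this work as a decorator

    @classmethod
    def get(cls, impl_name: str) -> Type[SketchImpl]:
        if impl_name not in cls._registry:
            raise ValueError(f"no implementation registered as {impl_name!r}")
        return cls._registry[impl_name]

    @classmethod
    def create(cls, impl_name: str, metadata: dict) -> SketchImpl:
        return cls.get(impl_name)(metadata)

    @classmethod
    def list_implementations(cls) -> List[str]:
        return list(cls._registry)


@SketchRegistry.register
class VanillaSketch(SketchImpl):
    @classmethod
    def get_impl_name(cls) -> str:
        return "vanilla"


impl = SketchRegistry.create("vanilla", metadata={"softmax_scale": 0.125})
```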

class easydel.__init__.AutoEasyDeLConfig[source]#

Bases: object

classmethod from_pretrained(pretrained_model_name_or_path: str, sharding_axis_dims: Sequence[int] = (1, -1, 1, 1), sharding_dcn_axis_dims: Optional[Sequence[int]] = None, sharding_axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), partition_axis: Optional[PartitionAxis] = None, shard_attention_computation: bool = True, backend: Optional[EasyDeLBackends] = None, platform: Optional[EasyDeLPlatforms] = None, model_task: TaskType = TaskType.CAUSAL_LM, from_torch: bool = False, **kwargs) EasyDeLBaseConfig[source]#

The from_pretrained function is a helper that loads the configuration of a pretrained model. It takes the name or path of the model (e.g., ‘bert-base-uncased’) and returns the corresponding configuration object (an EasyDeLBaseConfig).

Parameters
  • pretrained_model_name_or_path (str) – Name or path identifying the model on the Hugging Face model hub.

  • sharding_axis_dims (tp.Sequence[int]) – The size of each sharding axis used to shard arrays in EasyDeL.

  • shard_attention_computation (bool) – Whether to use shard_map for attention.

  • backend (tp.Optional[EasyDeLBackends]) – Backend to use for the model.

  • model_task (TaskType) – Task type used to find and load the model.

  • from_torch – Whether the config should be loaded from torch models.

  • **kwargs – Additional arguments passed to the model and config classes.

Returns

A model config.

class easydel.__init__.AutoEasyDeLModel[source]#

Bases: BaseAutoEasyModel

This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL compatible models. It utilizes the EasyDeL library for distributed training and inference with JAX.

This class inherits from the BaseAutoEasyModel class, providing functionality for model loading, parameter sharding, and interaction with the EasyDeL framework.

model_task: TaskType = 'base-module'#
class easydel.__init__.AutoEasyDeLModelForCausalLM[source]#

Bases: BaseAutoEasyModel

This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL compatible models. It utilizes the EasyDeL library for distributed training and inference with JAX.

This class inherits from the BaseAutoEasyModel class, providing functionality for model loading, parameter sharding, and interaction with the EasyDeL framework.

Examples

>>> import jax
>>> from easydel import AutoEasyDeLModelForCausalLM
>>> # Load a GPT-2 model on a single CPU
>>> model = AutoEasyDeLModelForCausalLM.from_pretrained(
...   "gpt2", device=jax.devices("cpu")[0]
... )
>>> # Load a GPT-2 model sharded across 8 GPUs with data parallelism (DP) and fully sharded data parallelism (FSDP)
>>> model = AutoEasyDeLModelForCausalLM.from_pretrained(
...   "gpt2",
...   sharding_axis_dims=(1, 8, 1, 1),
...   sharding_axis_names=("dp", "fsdp", "tp", "sp"),
...   device=jax.devices("cpu")[0],  # offload to CPU [OPTIONAL]
...   from_torch=True,
... )
model_task: TaskType = 'causal-language-model'#
class easydel.__init__.AutoEasyDeLModelForImageTextToText[source]#

Bases: BaseAutoEasyModel

This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL compatible models. It utilizes the EasyDeL library for distributed training and inference with JAX.

This class inherits from the BaseAutoEasyModel class, providing functionality for model loading, parameter sharding, and interaction with the EasyDeL framework.

model_task: TaskType = 'image-text-to-text'#
class easydel.__init__.AutoEasyDeLModelForSeq2SeqLM[source]#

Bases: BaseAutoEasyModel

This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL compatible models. It utilizes the EasyDeL library for distributed training and inference with JAX.

This class inherits from the BaseAutoEasyModel class, providing functionality for model loading, parameter sharding, and interaction with the EasyDeL framework.

model_task: TaskType = 'sequence-to-sequence'#
class easydel.__init__.AutoEasyDeLModelForSequenceClassification[source]#

Bases: BaseAutoEasyModel

This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL compatible models. It utilizes the EasyDeL library for distributed training and inference with JAX.

This class inherits from the BaseAutoEasyModel class, providing functionality for model loading, parameter sharding, and interaction with the EasyDeL framework.

model_task: TaskType = 'sequence-classification'#
class easydel.__init__.AutoEasyDeLModelForSpeechSeq2Seq[source]#

Bases: BaseAutoEasyModel

This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL compatible models. It utilizes the EasyDeL library for distributed training and inference with JAX.

This class inherits from the BaseAutoEasyModel class, providing functionality for model loading, parameter sharding, and interaction with the EasyDeL framework.

Examples

>>> import jax
>>> from easydel import AutoEasyDeLModelForSpeechSeq2Seq
>>> # Load openai/whisper-large-v3-turbo with automatic sharding
>>> model = AutoEasyDeLModelForSpeechSeq2Seq.from_pretrained(
...   "openai/whisper-large-v3-turbo",
...   auto_shard_model=True,
... )
>>> # Load openai/whisper-large-v3-turbo sharded across 8 GPUs with data parallelism (DP) and fully sharded data parallelism (FSDP)
>>> model = AutoEasyDeLModelForSpeechSeq2Seq.from_pretrained(
...   "openai/whisper-large-v3-turbo",
...   sharding_axis_dims=(1, 8, 1, 1),
...   sharding_axis_names=("dp", "fsdp", "tp", "sp"),
...   device=jax.devices("cpu")[0],  # offload to CPU [OPTIONAL]
...   from_torch=True,
... )
model_task: TaskType = 'speech-sequence-to-sequence'#
class easydel.__init__.AutoEasyDeLModelForZeroShotImageClassification[source]#

Bases: BaseAutoEasyModel

This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL compatible models. It utilizes the EasyDeL library for distributed training and inference with JAX.

This class inherits from the BaseAutoEasyModel class, providing functionality for model loading, parameter sharding, and interaction with the EasyDeL framework.

model_task: TaskType = 'zero-shot-image-classification'#
class easydel.__init__.AutoEasyDeLVisionModel[source]#

Bases: BaseAutoEasyModel

This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL compatible models. It utilizes the EasyDeL library for distributed training and inference with JAX.

This class inherits from the BaseAutoEasyModel class, providing functionality for model loading, parameter sharding, and interaction with the EasyDeL framework.

model_task: TaskType = 'vision-module'#
class easydel.__init__.AutoShardAndGatherFunctions[source]#

Bases: object

A class to automatically generate shard and gather functions for a given model configuration.

This class provides three methods to generate shard and gather functions:

  • from_config: Generates functions based on a provided EasyDeLBaseConfig object.

  • from_params: Generates functions directly from model parameters, partition rules, and a mesh.

  • from_pretrained: Generates functions based on a pretrained model name or path.

classmethod from_config(config: EasyDeLBaseConfig, partition_rules: Optional[Tuple[Tuple[str, PartitionSpec]]] = None, flatten: bool = True, model_task: TaskType = TaskType.CAUSAL_LM, depth_target: Optional[List[str]] = None)[source]#

Generates shard and gather functions based on a provided EasyDeLBaseConfig object.

Parameters
  • config – An EasyDeLBaseConfig object containing the model configuration.

  • partition_rules – A tuple of tuples containing partition rule names and PartitionSpec objects. If None, uses the default partition rules from the config.

  • flatten – Whether to flatten the shard and gather functions. Defaults to True.

  • model_task (TaskType) – Task type used to find and load the model.

  • depth_target – Keys under which to nest the sharding tree. For example, depth_target = [“row”] turns {params: tensor} into {row: {params: tensor}}. Defaults to None.

Returns

A tuple containing the shard and gather functions.

static from_params(params, partition_rules, mesh)[source]#

Generates shard and gather functions directly from model parameters, partition rules, and a mesh.

Parameters
  • params – The model parameters (pytree) to generate functions for.

  • partition_rules – A tuple of tuples defining the partitioning strategy.

  • mesh – The JAX device mesh to use for sharding.

Returns

A tuple containing the shard and gather functions.
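
Partition rules are pairs of a parameter-name pattern and a PartitionSpec. One common way such rules are applied is sketched below. It assumes the first rule whose regular expression matches a flattened parameter path wins; plain tuples stand in for PartitionSpec objects and the parameter names are hypothetical. The library's own matching logic may differ.

```python
import re
from typing import Dict, Mapping, Sequence, Tuple

Spec = Tuple  # stand-in for jax.sharding.PartitionSpec in this sketch


def match_partition_rules(
    rules: Sequence[Tuple[str, Spec]],
    flat_params: Mapping[str, object],
) -> Dict[str, Spec]:
    """Assign each parameter path the spec of the first rule that matches it."""
    specs: Dict[str, Spec] = {}
    for name in flat_params:
        for pattern, spec in rules:
            if re.search(pattern, name) is not None:
                specs[name] = spec
                break
        else:
            # No rule matched: fail loudly instead of silently replicating.
            raise ValueError(f"no partition rule matches {name!r}")
    return specs


rules = (
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"self_attn/.*proj/kernel", ("fsdp", "tp")),
    (r".*", (None,)),  # catch-all: leave everything else unsharded
)

# Hypothetical flattened parameter paths mapped to their shapes.
flat_params = {
    "model/embed_tokens/embedding": (32000, 4096),
    "model/layers/0/self_attn/q_proj/kernel": (4096, 4096),
    "model/layers/0/input_layernorm/scale": (4096,),
}

specs = match_partition_rules(rules, flat_params)
```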

classmethod from_pretrained(pretrained_model_name_or_path: str, sharding_axis_dims: Sequence[int] = (1, -1, 1, 1), sharding_dcn_axis_dims: Optional[Sequence[int]] = None, sharding_axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), partition_axis: Optional[PartitionAxis] = None, shard_attention_computation: bool = True, backend: Optional[EasyDeLBackends] = None, platform: Optional[EasyDeLPlatforms] = None, partition_rules: Optional[Tuple[Tuple[str, PartitionSpec]]] = None, flatten: bool = True, config_kwargs: Optional[Mapping[str, Any]] = None, model_task: TaskType = TaskType.CAUSAL_LM, from_torch: bool = False, trust_remote_code: bool = False) Tuple[Mapping[str, Callable], Mapping[str, Callable]][source]#

Generates shard and gather functions based on a pretrained model name or path.

Parameters
  • pretrained_model_name_or_path – The name or path of the pretrained model.

  • sharding_axis_dims – The dimensions of the sharding axes. Defaults to (1, -1, 1, 1).

  • sharding_axis_names – The names of the sharding axes. Defaults to (“dp”, “fsdp”, “tp”, “sp”).

  • partition_axis (PartitionAxis) – A PartitionAxis instance describing how arrays are partitioned in EasyDeL. Defaults to None.

  • shard_attention_computation – Whether to shard the attention computation. Defaults to True.

  • backend – The backend to use for custom kernels. Defaults to None.

  • partition_rules – A tuple of tuples containing partition rule names and PartitionSpec objects. If None, uses the default partition rules from the config.

  • flatten – Whether to flatten the shard and gather functions. Defaults to True.

  • config_kwargs – Additional keyword arguments to pass to the AutoEasyDeLConfig constructor. Defaults to None.

  • model_task (TaskType) – The task type used to locate and load the model. Defaults to TaskType.CAUSAL_LM.

  • from_torch (bool) – Whether to load the config from a PyTorch model. Defaults to False.

  • trust_remote_code (bool) – Whether to trust remote code loaded from the Hugging Face Hub. Defaults to False.

Returns

A tuple containing the shard and gather functions.
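The default sharding_axis_dims of (1, -1, 1, 1) contains a -1 placeholder. A minimal sketch of how such a placeholder could be resolved follows, assuming -1 means "use all remaining devices" in the same way as -1 in an array reshape. The helper resolve_axis_dims is hypothetical and not part of the library.

```python
import math
from typing import Sequence, Tuple


def resolve_axis_dims(axis_dims: Sequence[int], num_devices: int) -> Tuple[int, ...]:
    """Replace a single -1 entry so the product of all axes equals num_devices."""
    if list(axis_dims).count(-1) > 1:
        raise ValueError("At most one axis may be -1.")
    # Product of the explicitly sized axes.
    known = math.prod(d for d in axis_dims if d != -1)
    if num_devices % known != 0:
        raise ValueError(f"{num_devices} devices cannot be split by axis dims {tuple(axis_dims)}.")
    if -1 not in axis_dims and known != num_devices:
        raise ValueError("Axis dims must multiply to the number of devices.")
    return tuple(num_devices // known if d == -1 else d for d in axis_dims)


dims = resolve_axis_dims((1, -1, 1, 1), num_devices=8)
print(dict(zip(("dp", "fsdp", "tp", "sp"), dims)))
# {'dp': 1, 'fsdp': 8, 'tp': 1, 'sp': 1}
```

Under this reading, the defaults place every available device on the fsdp axis.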

class easydel.__init__.AutoState[source]#

Bases: BaseAutoEasyState

class easydel.__init__.AutoStateForCausalLM[source]#

Bases: BaseAutoEasyState

class easydel.__init__.AutoStateForImageSequenceClassification[source]#

Bases: BaseAutoEasyState

class easydel.__init__.AutoStateForImageTextToText[source]#

Bases: BaseAutoEasyState

class easydel.__init__.AutoStateForSeq2SeqLM[source]#

Bases: BaseAutoEasyState

class easydel.__init__.AutoStateForSpeechSeq2Seq[source]#

Bases: BaseAutoEasyState

class easydel.__init__.AutoStateForZeroShotImageClassification[source]#

Bases: BaseAutoEasyState

class easydel.__init__.AutoStateVisionModel[source]#

Bases: BaseAutoEasyState

class easydel.__init__.AyaVisionConfig(vision_config=None, text_config=None, vision_feature_select_strategy='full', vision_feature_layer=-1, downsample_factor=2, adapter_layer_norm_eps=1e-06, image_token_index=255036, **kwargs)[source]#

Bases: EasyDeLBaseConfig

This is the configuration class to store the configuration of a [AyaVisionForConditionalGeneration]. It is used to instantiate an AyaVision model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of AyaVision. e.g. [CohereForAI/aya-vision-8b](https://huggingface.co/CohereForAI/aya-vision-8b)

Configuration objects inherit from [PretrainedConfig] and can be used to control the model outputs. Read the documentation from [PretrainedConfig] for more information.

Parameters
  • vision_config (Union[AutoConfig, dict], optional, defaults to CLIPVisionConfig) – The config object or dictionary of the vision backbone.

  • text_config (Union[AutoConfig, dict], optional, defaults to LlamaConfig) – The config object or dictionary of the text backbone.

  • vision_feature_select_strategy (str, optional, defaults to “full”) – The feature selection strategy used to select the vision feature from the vision backbone. Can be one of “default” or “full”. If “default”, the CLS token is removed from the vision features. If “full”, the full vision features are used.

  • vision_feature_layer (int, optional, defaults to -1) – The index of the layer to select the vision feature.

  • downsample_factor (int, optional, defaults to 2) – The downsample factor to apply to the vision features.

  • adapter_layer_norm_eps (float, optional, defaults to 1e-06) – The epsilon value used for layer normalization in the adapter.

  • image_token_index (int, optional, defaults to 255036) – The image token index to encode the image prompt.
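The interaction of vision_feature_layer and vision_feature_select_strategy can be sketched without any model code. The helper below is hypothetical; it assumes the CLS token sits at index 0 of each layer's token sequence, and strings stand in for feature vectors.

```python
from typing import List, Sequence


def select_vision_features(
    all_layer_states: Sequence[Sequence[str]],
    vision_feature_layer: int = -1,
    strategy: str = "full",
) -> List[str]:
    """Pick one layer's token features, optionally dropping the CLS token."""
    tokens = list(all_layer_states[vision_feature_layer])
    if strategy == "default":
        return tokens[1:]  # drop the CLS token (assumed to be at index 0)
    if strategy == "full":
        return tokens  # keep every token, CLS included
    raise ValueError(f"Unexpected feature selection strategy: {strategy!r}")


# Two layers, each holding [CLS, patch, patch].
layers = [["cls0", "a0", "b0"], ["cls1", "a1", "b1"]]
full = select_vision_features(layers, -1, "full")
default = select_vision_features(layers, -1, "default")
print(full)     # ['cls1', 'a1', 'b1']
print(default)  # ['a1', 'b1']
```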

get_partition_rules(*args, **kwargs)[source]#

Retrieves the combined partition rules from the text and vision configurations.

Parameters
  • *args – Positional arguments passed to the underlying config partition rule methods.

  • **kwargs – Keyword arguments passed to the underlying config partition rule methods.

Returns

Combined partition rules from both text and vision models.

Return type

Tuple

model_type: str = 'aya_vision'#
sub_configs: dict[str, 'PretrainedConfig'] = {'text_config': <class 'easydel.__init__.modules.auto.auto_configuration.AutoEasyDeLConfig'>, 'vision_config': <class 'easydel.__init__.modules.auto.auto_configuration.AutoEasyDeLConfig'>}#
class easydel.__init__.AyaVisionForConditionalGeneration(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

AyaVision model for conditional text generation based on image inputs. Combines a vision tower and a language model with a multi-modal projector.

config#

Configuration object.

Type

AyaVisionConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

JAX precision level.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

get_image_features(pixel_values: Union[Array, ndarray, bool, number]) Union[Array, ndarray, bool, number][source]#

Extracts and projects image features from the vision tower.

Parameters

pixel_values (chex.Array) – Input pixel values for the images.

Returns

Processed image features ready for the language model.

Return type

chex.Array

loss_type = 'ForCausalLM'#
prepare_inputs_for_generation(input_ids: Union[Array, ndarray, bool, number], max_length: int, pixel_values: Optional[Union[Array, ndarray, bool, number]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Prepares inputs for text generation, including pixel values if provided.

Parameters
  • input_ids (chex.Array) – Initial input token IDs.

  • max_length (int) – Maximum generation length.

  • pixel_values (Optional[chex.Array]) – Pixel values for image input.

  • attention_mask (Optional[chex.Array]) – Attention mask.

Returns

Model inputs ready for generation.

Return type

dict

update_inputs_for_generation(model_outputs, model_kwargs)[source]#

Updates model inputs for the next step of generation, removing pixel values after the first step.

Parameters
  • model_outputs – Outputs from the previous generation step.

  • model_kwargs – Current keyword arguments for the model.

Returns

Updated model keyword arguments.

Return type

dict
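The behaviour described for update_inputs_for_generation, dropping pixel values once the first step has consumed them, can be sketched on a plain dictionary. This is an illustrative stand-in rather than the library's implementation; the function name and the dictionary contents are assumptions.

```python
from typing import Any, Dict


def drop_pixel_values(model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of the kwargs without `pixel_values` (hypothetical helper)."""
    updated = dict(model_kwargs)  # copy so the caller's dict is left untouched
    updated.pop("pixel_values", None)  # no error if the key is already absent
    return updated


step0 = {"pixel_values": [[0.1, 0.2]], "attention_mask": [1, 1, 1]}
step1 = drop_pixel_values(step0)
print(sorted(step1))  # ['attention_mask']
```

Calling the helper again on later steps is harmless, since the key is simply missing by then.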

class easydel.__init__.BaseTrainer(arguments: tp.Optional[TrainingArguments] = None, model_state: tp.Optional[EasyDeLState] = None, model: tp.type[EasyDeLBaseModule] = None, dataset_train: tp.Optional[Dataset] = None, dataset_eval: tp.Optional[Dataset] = None, data_collator: tp.Optional[tp.Callable] = None, finetune: bool = True, checkpoint_path: tp.Optional[tp.Union[str, os.PathLike]] = None, **deprecated_kwargs)[source]#

Bases: BaseTrainerProtocol

apply_training_hooks(metrics: LossMetrics) LossMetrics[source]#

Apply training hooks to the model.

calculate_number_total_flops(params, is_training=True)[source]#
compile_aot() bool[source]#

Compile the state ahead of time for faster execution.

configure_dataloaders() TrainerConfigureDataloaderOutput[source]#

Configures the dataloaders for training and evaluation.

This method creates the training and evaluation dataloaders using the provided datasets and data collator. It also determines the maximum number of training and evaluation steps based on the dataset sizes and training arguments.

Returns

An object containing the configured dataloaders and the maximum number of training and evaluation steps.

Return type

TrainerConfigureDataloaderOutput
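One plausible way to derive a maximum step count from a dataset size and training arguments is sketched below. The formula and the parameter names (num_epochs, step_limit) are assumptions for illustration; the trainer's actual calculation may differ.

```python
import math
from typing import Optional


def estimate_max_steps(
    dataset_size: int,
    batch_size: int,
    num_epochs: int = 1,
    step_limit: Optional[int] = None,
) -> int:
    """Steps needed to cover the dataset `num_epochs` times, optionally capped."""
    steps_per_epoch = math.ceil(dataset_size / batch_size)  # last partial batch counts
    total = steps_per_epoch * num_epochs
    return total if step_limit is None else min(total, step_limit)


print(estimate_max_steps(1000, 32, num_epochs=3))                 # 96
print(estimate_max_steps(1000, 32, num_epochs=3, step_limit=50))  # 50
```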

abstract configure_functions() TrainerConfigureFunctionOutput[source]#

Configures and JIT-compiles the training and evaluation step functions.

This method sets up the necessary functions for training and evaluation, including:
  • Initialization of the model state.

  • Sharding of the model parameters and optimizer state.

  • JIT-compilation of the training and evaluation step functions.

Returns

An object containing the configured functions and other relevant information.

Return type

TrainerConfigureFunctionOutput

configure_model() TrainerConfigureModelOutput[source]#

Configures the model, optimizer, scheduler, and configuration.

This method retrieves the model configuration from the model state, creates the optimizer and scheduler using the training arguments, and returns an object containing the configured model, optimizer, scheduler, and configuration.

Returns

An object containing the configured model, optimizer, scheduler, and configuration.

Return type

TrainerConfigureModelOutput

static count_model_parameters(prm)[source]#

Prints the number of model parameters in billions.
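Counting parameters amounts to summing the element counts of every leaf in the parameter tree. The sketch below works on a nested dictionary of shapes rather than real arrays, so it runs without JAX; the helper name and the example shapes are hypothetical.

```python
import math
from typing import Any, Dict


def count_parameters(shapes: Dict[str, Any]) -> int:
    """Sum the element counts of all leaf shapes in a nested dict."""
    total = 0
    for value in shapes.values():
        if isinstance(value, dict):
            total += count_parameters(value)  # recurse into sub-trees
        else:
            total += math.prod(value)  # number of elements for this shape
    return total


shapes = {
    "embed": {"embedding": (32000, 4096)},
    "lm_head": {"kernel": (4096, 32000)},
}
total = count_parameters(shapes)
print(f"{total / 1e9:.3f}B parameters")  # 0.262B parameters
```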

abstract create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start']) Callable[source]#

Creates a function to collect and process batches of data for training or evaluation.

This function handles padding or truncating sequences to the specified max_sequence_length based on the chosen truncation_mode.

Parameters
  • max_sequence_length (int) – The maximum allowed sequence length.

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode. Defaults to “keep_end”.

Returns

A function that takes a batch of data and returns a processed batch.

Return type

tp.Callable
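A concrete subclass has to supply this function. The following dependency-free sketch shows one way a collect function could pad or truncate token lists, assuming "keep_end" keeps the last max_sequence_length tokens and "keep_start" keeps the first ones. The pad_token_id parameter and the batch layout are assumptions, not part of the documented signature.

```python
from typing import Callable, Dict, List


def make_collect_function(
    max_sequence_length: int,
    truncation_mode: str = "keep_end",
    pad_token_id: int = 0,  # hypothetical extra parameter
) -> Callable[[List[Dict[str, List[int]]]], Dict[str, List[List[int]]]]:
    if truncation_mode not in ("keep_end", "keep_start"):
        raise ValueError(f"Unknown truncation mode: {truncation_mode!r}")

    def collect(batch: List[Dict[str, List[int]]]) -> Dict[str, List[List[int]]]:
        out: Dict[str, List[List[int]]] = {"input_ids": [], "attention_mask": []}
        for example in batch:
            ids = example["input_ids"]
            if len(ids) > max_sequence_length:
                if truncation_mode == "keep_end":
                    ids = ids[-max_sequence_length:]  # keep the tail
                else:
                    ids = ids[:max_sequence_length]  # keep the head
            n_pad = max_sequence_length - len(ids)
            # Right-pad and mark padded positions with 0 in the mask.
            out["input_ids"].append(ids + [pad_token_id] * n_pad)
            out["attention_mask"].append([1] * len(ids) + [0] * n_pad)
        return out

    return collect


collect = make_collect_function(4, "keep_end")
batch = collect([{"input_ids": [1, 2, 3, 4, 5, 6]}, {"input_ids": [7, 8]}])
print(batch["input_ids"])       # [[3, 4, 5, 6], [7, 8, 0, 0]]
print(batch["attention_mask"])  # [[1, 1, 1, 1], [1, 1, 0, 0]]
```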

create_progress_bar(total: int, desc: str = '', disabled: bool = False) BaseProgressBar[source]#

Create a progress bar of the specified type.

property evaluation_batch_size#
static finish()[source]#

Finalize the training process.

get_runstage_flops(is_training) Union[float, Tuple[float, bool]][source]#

Return the total number of FLOPs for the model.

initialize_trainer_utils()[source]#

Initializes various utilities used by the trainer.

This includes setting up Weights & Biases, initializing the training timer, configuring dataloaders, configuring the model and optimizer, sharding the model and reference model states, and configuring the training and evaluation functions.

property is_process_zero#
log_metrics(metrics: Any, pbar: BaseProgressBar, step: int, mode: str = 'train')[source]#

Log metrics and update progress bar.

log_weight_distribution(state: EasyDeLState, step: int)[source]#

Log distribution of weights.

property mesh#
property model#
on_step_end(state: EasyDeLState, metrics: Any, step: int) Tuple[EasyDeLState, Any][source]#

Hook called at the end of each step.

on_step_start(state: EasyDeLState, step: int) EasyDeLState[source]#

Hook called at the start of each step.

save_information(output_path: Union[str, Path]) None[source]#

Save the generated information to a markdown file.

Parameters

output_path – Path where the markdown file should be saved

save_pretrained(state: EasyDeLState, save_directory: Optional[str] = None, gather_fns: Optional[Union[Any, Mapping[str, Callable], dict[Callable]]] = None, to_torch: bool = False, easystate_to_huggingface_model_kwargs: Optional[dict] = None, torch_save_pretrained_kwargs: Optional[dict] = None)[source]#

Saves the model state as a checkpoint file or to a Torch compatible directory.

specs_to_name_sharding(tree, mesh=None)[source]#

Convert specs to named sharding.

start_evaluation_hook()[source]#

Hook to run before evaluation starts.

start_training_hook()[source]#

Hook to run before training starts.

property training_batch_size#
class easydel.__init__.CLIPConfig(text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs)[source]#

Bases: EasyDeLBaseConfig

[CLIPConfig] is the configuration class to store the configuration of a [CLIPModel]. It is used to instantiate a CLIP model according to the specified arguments, defining the text model and vision model configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the CLIP [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • text_config (dict, optional) – Dictionary of configuration options used to initialize [CLIPTextConfig].

  • vision_config (dict, optional) – Dictionary of configuration options used to initialize [CLIPVisionConfig].

  • projection_dim (int, optional, defaults to 512) – Dimensionality of text and vision projection layers.

  • logit_scale_init_value (float, optional, defaults to 2.6592) – The initial value of the logit_scale parameter. Default is used as per the original CLIP implementation.

  • kwargs (optional) – Dictionary of keyword arguments.
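The default logit_scale_init_value is easier to interpret once exponentiated. In the original CLIP formulation the scale is stored in log space and initialised to log(1 / 0.07); assuming the same convention applies here, the arithmetic works out as follows.

```python
import math

logit_scale_init_value = 2.6592

# The stored value is a log-scale; exponentiate to get the multiplier on the logits.
scale = math.exp(logit_scale_init_value)
# The reciprocal is the softmax temperature.
temperature = 1.0 / scale

print(f"scale ~ {scale:.2f}, temperature ~ {temperature:.3f}")
# scale ~ 14.28, temperature ~ 0.070
```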

Example:

```python
>>> from transformers import CLIPConfig, CLIPModel

>>> # Initializing a CLIPConfig with openai/clip-vit-base-patch32 style configuration
>>> configuration = CLIPConfig()
>>> # Initializing a CLIPModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
>>> model = CLIPModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
>>> # We can also initialize a CLIPConfig from a CLIPTextConfig and a CLIPVisionConfig
>>> from transformers import CLIPTextConfig, CLIPVisionConfig
>>> # Initializing a CLIPText and CLIPVision configuration
>>> config_text = CLIPTextConfig()
>>> config_vision = CLIPVisionConfig()
>>> config = CLIPConfig.from_text_vision_configs(config_text, config_vision)
```
classmethod from_text_vision_configs(text_config: CLIPTextConfig, vision_config: CLIPVisionConfig, **kwargs)[source]#

Instantiate a [CLIPConfig] (or a derived class) from clip text model configuration and clip vision model configuration.

Parameters
  • text_config (CLIPTextConfig) – The text model configuration.

  • vision_config (CLIPVisionConfig) – The vision model configuration.

  • **kwargs – Additional keyword arguments.

Returns

An instance of a configuration object

Return type

[CLIPConfig]

get_partition_rules(*arg, **kwargs)#

Generic partition rules for CLIP text and vision models.

Parameters
  • self – The configuration object (unused but part of method signature).

  • *arg – Additional positional arguments (unused).

  • **kwargs – Additional keyword arguments (unused).

Returns

A tuple of partition rules for model parameters.

Return type

Tuple

model_type: str = 'clip'#
sub_configs: dict[str, 'PretrainedConfig'] = {'text_config': <class 'easydel.__init__.modules.clip.clip_configuration.CLIPTextConfig'>, 'vision_config': <class 'easydel.__init__.modules.clip.clip_configuration.CLIPVisionConfig'>}#
class easydel.__init__.CLIPForImageClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

CLIP vision model with an image classification head on top (a linear layer on the pooled final hidden state).

config#

Configuration object.

Type

CLIPVisionConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

JAX precision level.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.__init__.CLIPModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

compute_loss(*, labels=None, loss_config=None, loss_kwargs=None, **batch) Tuple[Any, CLIPOutput][source]#

Computes the loss for the model given a batch of inputs and labels.

This method performs a forward pass using the provided batch arguments, then calculates the loss using the determined loss_function. It handles potential label inference (e.g., using input_ids as labels for Causal LM) and default loss configurations.

Parameters
  • labels (tp.Optional[chex.Array], optional) – The target labels. If None and the task is Causal LM, input_ids from the batch might be used. Defaults to None.

  • loss_config (tp.Optional[LossConfig], optional) – Specific configuration for the loss calculation. If None, defaults might be inferred (e.g., for sequence classification). Defaults to None.

  • loss_kwargs (tp.Optional[tp.Dict], optional) – Additional keyword arguments to pass directly to the loss function. Defaults to None.

  • **batch – Keyword arguments representing the input batch (e.g., input_ids, attention_mask).

Returns

A tuple containing:
  • The model’s output (a pytree typically including logits, hidden states, etc.).

  • A LossMetrics object containing the calculated loss and potentially other metrics.

Return type

tp.Tuple[tp.Any, LossMetrics]

Raises
  • AssertionError – If labels are required for the loss function but are not provided or inferred.

  • AssertionError – If sequence classification loss is used without num_labels in the config.

get_image_features(pixel_values: Union[Array, ndarray, bool, number])[source]#
get_text_features(input_ids: Union[Array, ndarray, bool, number], attention_mask: Optional[Union[Array, ndarray, bool, number]] = None, position_ids: Optional[Union[Array, ndarray, bool, number]] = None)[source]#
class easydel.__init__.CLIPTextConfig(vocab_size=49408, hidden_size=512, intermediate_size=2048, projection_dim=512, num_hidden_layers=12, num_attention_heads=8, max_position_embeddings=77, hidden_act='quick_gelu', layer_norm_eps=1e-05, attention_dropout=0.0, initializer_range=0.02, initializer_factor=1.0, pad_token_id=1, bos_token_id=49406, eos_token_id=49407, **kwargs)[source]#

Bases: EasyDeLBaseConfig

This is the configuration class to store the configuration of a [CLIPTextModel]. It is used to instantiate a CLIP text encoder according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the text encoder of the CLIP [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 49408) – Vocabulary size of the CLIP text model. Defines the number of different tokens that can be represented by the inputs_ids passed when calling [CLIPModel].

  • hidden_size (int, optional, defaults to 512) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 2048) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • projection_dim (int, optional, defaults to 512) – Dimensionality of text and vision projection layers.

  • num_hidden_layers (int, optional, defaults to 12) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 8) – Number of attention heads for each attention layer in the Transformer encoder.

  • max_position_embeddings (int, optional, defaults to 77) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048).

  • hidden_act (str or function, optional, defaults to “quick_gelu”) – The non-linear activation function (function or string) in the encoder and pooler. If string, “gelu”, “relu”, “selu”, “gelu_new” and “quick_gelu” are supported.

  • layer_norm_eps (float, optional, defaults to 1e-05) – The epsilon used by the layer normalization layers.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • initializer_factor (float, optional, defaults to 1.0) – A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing).

  • pad_token_id (int, optional, defaults to 1) – Padding token id.

  • bos_token_id (int, optional, defaults to 49406) – Beginning of stream token id.

  • eos_token_id (int, optional, defaults to 49407) – End of stream token id.

Example:

```python
>>> from transformers import CLIPTextConfig, CLIPTextModel

>>> # Initializing a CLIPTextConfig with openai/clip-vit-base-patch32 style configuration
>>> configuration = CLIPTextConfig()
>>> # Initializing a CLIPTextModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
>>> model = CLIPTextModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
base_config_key: str = 'text_config'#
get_partition_rules(*arg, **kwargs)#

Generic partition rules for CLIP text and vision models.

Parameters
  • self – The configuration object (unused but part of method signature).

  • *arg – Additional positional arguments (unused).

  • **kwargs – Additional keyword arguments (unused).

Returns

A tuple of partition rules for model parameters.

Return type

Tuple

model_type: str = 'clip_text_model'#
class easydel.__init__.CLIPTextModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Bare CLIP text model (transformer) outputting raw hidden-states without any specific head on top.

config#

Configuration object.

Type

CLIPTextConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

JAX precision level.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.__init__.CLIPTextModelWithProjection(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

CLIP text model with a projection layer on top.

config#

Configuration object.

Type

CLIPTextConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

JAX precision level.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.__init__.CLIPVisionConfig(hidden_size=768, intermediate_size=3072, projection_dim=512, num_hidden_layers=12, num_attention_heads=12, num_channels=3, image_size=224, patch_size=32, hidden_act='quick_gelu', layer_norm_eps=1e-05, attention_dropout=0.0, initializer_range=0.02, initializer_factor=1.0, **kwargs)[source]#

Bases: EasyDeLBaseConfig

This is the configuration class to store the configuration of a [CLIPVisionModel]. It is used to instantiate a CLIP vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the vision encoder of the CLIP [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • hidden_size (int, optional, defaults to 768) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 3072) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • projection_dim (int, optional, defaults to 512) – Dimensionality of text and vision projection layers.

  • num_hidden_layers (int, optional, defaults to 12) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 12) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_channels (int, optional, defaults to 3) – The number of input channels.

  • image_size (int, optional, defaults to 224) – The size (resolution) of each image.

  • patch_size (int, optional, defaults to 32) – The size (resolution) of each patch.

  • hidden_act (str or function, optional, defaults to “quick_gelu”) – The non-linear activation function (function or string) in the encoder and pooler. If string, “gelu”, “relu”, “selu”, “gelu_new” and “quick_gelu” are supported.

  • layer_norm_eps (float, optional, defaults to 1e-05) – The epsilon used by the layer normalization layers.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • initializer_factor (float, optional, defaults to 1.0) – A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing).
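The image_size and patch_size defaults determine how many tokens the vision encoder processes. A short worked calculation, assuming square images, non-overlapping patches, and one extra class token as in the standard ViT layout (the helper name is hypothetical):

```python
def vision_sequence_length(
    image_size: int = 224,
    patch_size: int = 32,
    add_class_token: bool = True,
) -> int:
    """Number of tokens produced for one square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size.")
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    return num_patches + 1 if add_class_token else num_patches


print(vision_sequence_length())                       # 50 (49 patches + class token)
print(vision_sequence_length(add_class_token=False))  # 49
```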

Example:

```python
>>> from transformers import CLIPVisionConfig, CLIPVisionModel

>>> # Initializing a CLIPVisionConfig with openai/clip-vit-base-patch32 style configuration
>>> configuration = CLIPVisionConfig()
>>> # Initializing a CLIPVisionModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
>>> model = CLIPVisionModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
base_config_key: str = 'vision_config'#
get_partition_rules(*arg, **kwargs)#

Generic partition rules for CLIP text and vision models.

Parameters
  • self – The configuration object (unused but part of method signature).

  • *arg – Additional positional arguments (unused).

  • **kwargs – Additional keyword arguments (unused).

Returns

A tuple of partition rules for model parameters.

Return type

Tuple

model_type: str = 'clip_vision_model'#
class easydel.__init__.CLIPVisionModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Bare CLIP vision model (transformer) outputting raw hidden-states without any specific head on top.

config#

Configuration object.

Type

CLIPVisionConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

JAX precision level.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.__init__.Cohere2Config(vocab_size=256000, hidden_size=8192, intermediate_size=22528, logit_scale=0.0625, num_hidden_layers=40, num_attention_heads=64, num_key_value_heads=None, hidden_act='silu', max_position_embeddings=8192, initializer_range=0.02, layer_norm_eps=1e-05, use_cache=True, pad_token_id=0, bos_token_id=5, eos_token_id=255001, tie_word_embeddings=True, rope_theta=10000.0, rope_scaling=None, attention_bias=False, attention_dropout=0.0, sliding_window=4096, sliding_window_pattern=4, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 256000) – Vocabulary size of the Cohere model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 8192) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 22528) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • logit_scale (float, optional, defaults to 0.0625) – A logit scale value used in the attention layer.

  • num_hidden_layers (int, optional, defaults to 40) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 64) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer encoder. Will default to num_attention_heads if not set.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 8192) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • layer_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional, defaults to 0) – The index of the padding token in the vocabulary.

  • bos_token_id (int, optional, defaults to 5) – The index of the beginning of sequence token in the vocabulary.

  • eos_token_id (int, optional, defaults to 255001) – The index of the end of sequence token in the vocabulary.

  • tie_word_embeddings (bool, optional, defaults to True) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • attention_bias (bool, optional, defaults to False) – Whether to use attention bias.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • use_qk_norm (bool, optional, defaults to False) – Whether to use query and key normalization.

  • gradient_checkpointing (str, optional, defaults to “nothing_saveable”) – The gradient checkpointing configuration.

  • bits (int, optional) – The number of bits to quantize the model to.
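The relationship between hidden_size, num_attention_heads, and num_key_value_heads can be made concrete with the defaults above. The sketch assumes the common convention that the per-head dimension is hidden_size // num_attention_heads; the helper and its returned keys are hypothetical.

```python
from typing import Dict, Optional


def resolve_attention_shapes(
    hidden_size: int = 8192,
    num_attention_heads: int = 64,
    num_key_value_heads: Optional[int] = None,
) -> Dict[str, int]:
    """Derive head sizes, falling back to num_attention_heads for KV heads."""
    if num_key_value_heads is None:
        num_key_value_heads = num_attention_heads  # documented fallback
    if hidden_size % num_attention_heads != 0:
        raise ValueError("hidden_size must be divisible by num_attention_heads.")
    if num_attention_heads % num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads.")
    return {
        "head_dim": hidden_size // num_attention_heads,
        "num_key_value_heads": num_key_value_heads,
        "queries_per_kv_head": num_attention_heads // num_key_value_heads,
    }


print(resolve_attention_shapes())
# {'head_dim': 128, 'num_key_value_heads': 64, 'queries_per_kv_head': 1}
print(resolve_attention_shapes(num_key_value_heads=8)["queries_per_kv_head"])  # 8
```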

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()[source]#

Returns a tuple of parameter names for which weight decay should be excluded.

property granted_freq_max_position_embedding: int#

Returns the maximum position embedding size for frequency-based position embeddings, falling back to max_position_embeddings.

property granted_mask_max_position_embedding: int#

Returns the maximum position embedding size for mask-based position embeddings, falling back to max_position_embeddings.

model_type: str = 'cohere'#
static rng_keys()[source]#

Returns the names of the random number generator keys used by the model.

class easydel.__init__.Cohere2ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.Cohere2ForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.Cohere2Model(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.CohereConfig(vocab_size=256000, hidden_size=8192, intermediate_size=22528, logit_scale=0.0625, num_hidden_layers=40, num_attention_heads=64, num_key_value_heads=None, hidden_act='silu', max_position_embeddings=8192, initializer_range=0.02, layer_norm_eps=1e-05, use_cache=True, pad_token_id=0, bos_token_id=5, eos_token_id=255001, tie_word_embeddings=True, rope_theta=10000.0, attention_bias=False, attention_dropout=0.0, use_qk_norm: bool = False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 256000) – Vocabulary size of the Cohere model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 8192) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 22528) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • logit_scale (float, optional, defaults to 0.0625) – The scaling factor applied to the output logits.

  • num_hidden_layers (int, optional, defaults to 40) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 64) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer encoder. Will default to num_attention_heads if not set.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 8192) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • layer_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional, defaults to 0) – The index of the padding token in the vocabulary.

  • bos_token_id (int, optional, defaults to 5) – The index of the beginning of sequence token in the vocabulary.

  • eos_token_id (int, optional, defaults to 255001) – The index of the end of sequence token in the vocabulary.

  • tie_word_embeddings (bool, optional, defaults to True) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • attention_bias (bool, optional, defaults to False) – Whether to use attention bias.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • use_qk_norm (bool, optional, defaults to False) – Whether to use query and key normalization.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing policy.

  • bits (int, optional) – The number of bits to quantize the model to.
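
The num_key_value_heads fallback described above can be illustrated with a small stand-alone sketch. The helper name resolve_kv_heads is hypothetical and not part of EasyDeL, and the divisibility check is an assumption based on how grouped-query attention is usually set up.

```python
from typing import Optional, Tuple


def resolve_kv_heads(
    num_attention_heads: int, num_key_value_heads: Optional[int] = None
) -> Tuple[int, int]:
    """Hypothetical helper: return (kv_heads, query_heads_per_kv_head)."""
    # Documented fallback: use num_attention_heads when the value is not set.
    kv_heads = num_attention_heads if num_key_value_heads is None else num_key_value_heads
    # Assumption: query heads are split evenly across key/value heads.
    if num_attention_heads % kv_heads != 0:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    return kv_heads, num_attention_heads // kv_heads


# CohereConfig defaults: 64 attention heads, num_key_value_heads left unset.
print(resolve_kv_heads(64))     # (64, 1): plain multi-head attention
print(resolve_kv_heads(64, 8))  # (8, 8): eight query heads share each KV head
```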

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

Attaches the following EasyDeL-specific arguments to the configuration object:

Parameters
  • gradient_checkpointing (EasyDeLGradientCheckPointers) – The gradient checkpointing policy, which trades recomputation for lower memory use.

  • bits (Optional[int]) – The number of bits to quantize the model to.

  • **kwargs – Additional keyword arguments.

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]
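
A rough sketch of how rules of the shape (str, PartitionSpec) are typically consumed: each string is treated as a regular expression and the first match against a parameter path wins. The rule contents, the parameter paths, the first-match behaviour, and the tuple standing in for PartitionSpec are all assumptions made for illustration; the real rules come from get_partition_rules().

```python
import re
from typing import Optional, Sequence, Tuple

# Stand-in for jax.sharding.PartitionSpec so the sketch runs without JAX.
Spec = Tuple[Optional[str], ...]

# Hypothetical rules; real ones are returned by config.get_partition_rules().
RULES: Sequence[Tuple[str, Spec]] = (
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"self_attn/(q_proj|k_proj|v_proj)/kernel", ("fsdp", "tp")),
    (r"self_attn/o_proj/kernel", ("tp", "fsdp")),
    (r".*", (None,)),  # catch-all: replicate everything else
)


def match_rule(param_path: str, rules: Sequence[Tuple[str, Spec]] = RULES) -> Spec:
    """Return the spec of the first rule whose regex matches the parameter path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise KeyError(f"no partition rule matches {param_path!r}")


print(match_rule("model/layers/0/self_attn/q_proj/kernel"))  # ('fsdp', 'tp')
print(match_rule("model/norm/scale"))                        # (None,)
```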

static get_weight_decay_exclusions()[source]#

Returns a tuple of parameter names for which weight decay should be excluded.

property granted_freq_max_position_embedding: int#

Returns the maximum position embedding size for frequency-based position embeddings, falling back to max_position_embeddings.

property granted_mask_max_position_embedding: int#

Returns the maximum position embedding size for mask-based position embeddings, falling back to max_position_embeddings.

model_type: str = 'cohere'#
static rng_keys()[source]#

Returns the names of the random number generator keys used by the model.

class easydel.__init__.CohereForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.CohereForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Cohere model for sequence classification.

config#

Configuration object (must include num_labels).

Type

CohereConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

JAX precision level.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.__init__.CohereModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.ConfigType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: str, Enum

Enumeration defining types of configurations that can be registered.

MODULE_CONFIG#

Represents standard module configuration classes.

MODULE_CONFIG = 'module-config'#
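
Because ConfigType derives from both str and Enum, its members can be looked up by value and compared with plain strings. The snippet below re-declares the enum locally so that it runs without EasyDeL installed; it mirrors the documented member but is not the library's own definition.

```python
from enum import Enum


class ConfigType(str, Enum):
    """Local re-declaration of the documented enum, for illustration only."""

    MODULE_CONFIG = "module-config"


# Lookup by value returns the singleton member.
member = ConfigType("module-config")
print(member is ConfigType.MODULE_CONFIG)  # True

# Members also behave like strings in comparisons.
print(ConfigType.MODULE_CONFIG == "module-config")  # True
print(ConfigType.MODULE_CONFIG.value)  # module-config
```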
class easydel.__init__.DPOConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: tp.Optional[str] = None, clip_grad: tp.Optional[float] = None, custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = None, dataloader_num_workers: tp.Optional[int] = 0, dataloader_pin_memory: tp.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: tp.Optional[int] = None, evaluation_steps: tp.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: tp.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: tp.Optional[dict] = None, learning_rate: float = 1e-06, learning_rate_end: tp.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: tp.Optional[LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: tp.Optional[int] = None, max_sequence_length: tp.Optional[int] = 4096, max_training_steps: tp.Optional[int] = None, model_name: str = 'DPOTrainer', model_parameters: tp.Optional[dict] = None, metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: tp.Optional[int] = None, save_total_limit: tp.Optional[int] = None, scheduler: AVAILABLE_SCHEDULERS = EasyDeLSchedulers.NONE, sparsify_module: bool = False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: tp.Optional[dict] = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: tp.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: tp.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: tp.Optional[jnp.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: tp.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0, beta: float = 0.1, label_smoothing: float = 0.0, loss_type: tp.Literal['sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down'] = 'sigmoid', use_weighting: bool = False, label_pad_token_id: int = -100, padding_value: tp.Optional[int] = None, max_length: tp.Optional[int] = 512, max_prompt_length: tp.Optional[int] = 256, max_completion_length: tp.Optional[int] = None, is_encoder_decoder: tp.Optional[bool] = None, disable_dropout: bool = True, precompute_ref_log_probs: bool = False, dataset_num_proc: tp.Optional[int] = None, reference_free: bool = False, force_use_ref_model: bool = False, sync_ref_model: bool = False, ref_model_mixup_alpha: float = 0.9, ref_model_sync_steps: int = 64, rpo_alpha: tp.Optional[float] = None, tools: tp.Optional[tp.List[tp.Union[dict, tp.Callable]]] = None)[source]#

Bases: TrainingArguments

Configuration class for Direct Preference Optimization (DPO) training.

Inherits from TrainingArguments and adds parameters specific to DPO training as described in https://arxiv.org/abs/2305.18290. This configuration controls various aspects of the DPO training process including loss computation, model architecture, and dataset processing.

beta#

Temperature parameter (β) controlling deviation from the reference model. Higher values keep the trained policy closer to the reference model; lower values allow it to deviate further. Default: 0.1

Type

float

label_smoothing#

Smoothing factor for labels in loss calculation. Helps prevent overconfidence. 0.0 means no smoothing. Default: 0.0

Type

float
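
To make the roles of beta and label_smoothing concrete, here is a minimal pure-Python sketch of the 'sigmoid' loss for a single preference pair, following the formulation in the DPO paper with smoothing applied in the conservative-DPO style. The function names are hypothetical, and whether EasyDeL's implementation matches this term for term is an assumption.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))


def dpo_sigmoid_loss(
    policy_chosen: float,
    policy_rejected: float,
    ref_chosen: float,
    ref_rejected: float,
    beta: float = 0.1,
    label_smoothing: float = 0.0,
) -> float:
    """Loss for one pair, given summed sequence log-probabilities."""
    # How much more the policy prefers "chosen" over "rejected" than the reference does.
    logits = (policy_chosen - policy_rejected) - (ref_chosen - ref_rejected)
    return (
        -(1.0 - label_smoothing) * log_sigmoid(beta * logits)
        - label_smoothing * log_sigmoid(-beta * logits)
    )


# Policy identical to the reference: the loss is log(2).
print(round(dpo_sigmoid_loss(-10.0, -12.0, -10.0, -12.0), 4))  # 0.6931
```

The loss falls below log(2) once the policy separates the chosen and rejected responses more strongly than the reference does, and beta scales how quickly that happens.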

loss_type#

Type of contrastive loss function to use. Valid options: ‘sigmoid’, ‘hinge’, ‘ipo’, ‘exo_pair’, ‘nca_pair’, ‘robust’, ‘bco_pair’, ‘sppo_hard’, ‘aot’, ‘aot_pair’, ‘apo_zero’, ‘apo_down’. Default: ‘sigmoid’

Type

LOSS_FN_VARIENTS

use_weighting#

Whether to apply example weighting in loss calculation. Default: False

Type

bool

label_pad_token_id#

Token ID used for padding labels. Default: -100

Type

int

padding_value#

Value used for padding sequences. If None, uses model’s default padding token. Default: None

Type

int | None

max_length#

Maximum total sequence length (prompt + completion). Default: 512

Type

int | None

max_prompt_length#

Maximum length for prompt sequences. Default: 256

Type

int | None

max_completion_length#

Maximum length for completion sequences. Auto-calculated as max_length - max_prompt_length if None. Default: None

Type

int | None
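
The fallback described for max_completion_length amounts to a subtraction. The helper below is a hypothetical illustration of that rule, not the library's code.

```python
from typing import Optional


def resolve_max_completion_length(
    max_length: int,
    max_prompt_length: int,
    max_completion_length: Optional[int] = None,
) -> int:
    """Hypothetical helper applying the documented fallback."""
    if max_completion_length is not None:
        return max_completion_length
    # Whatever the prompt does not use is left for the completion.
    return max_length - max_prompt_length


# With the DPOConfig defaults (max_length=512, max_prompt_length=256):
print(resolve_max_completion_length(512, 256))  # 256
```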

is_encoder_decoder#

Explicitly set if model is encoder-decoder. Auto-detected if None. Default: None

Type

bool | None

disable_dropout#

Whether to disable dropout during training for deterministic behavior. Default: True

Type

bool

precompute_ref_log_probs#

Whether to precompute reference model log probabilities before training. Default: False

Type

bool

dataset_num_proc#

Number of processes for dataset preprocessing. Default: None (sequential processing)

Type

int | None

reference_free#

Whether to use reference-free variant of DPO. Default: False

Type

bool

force_use_ref_model#

Force the use of a reference model even when reference_free=True. Default: False

Type

bool

sync_ref_model#

Whether to periodically sync reference model with training model. Default: False

Type

bool

learning_rate#

Optimizer learning rate. Default: 1e-6

Type

float

ref_model_mixup_alpha#

Alpha parameter for mixup between policy and reference models. Default: 0.9

Type

float

ref_model_sync_steps#

Number of steps between reference model syncs. Default: 64

Type

int
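
The interaction of sync_ref_model, ref_model_mixup_alpha and ref_model_sync_steps can be sketched as a periodic soft update of the reference weights. The update direction used here (alpha weights the policy, as in the TR-DPO soft update) is an assumption, as are the function name and the use of plain floats in place of parameter arrays; check the trainer source for the exact rule.

```python
from typing import Dict


def maybe_sync_reference(
    step: int,
    policy: Dict[str, float],
    reference: Dict[str, float],
    alpha: float = 0.9,
    sync_steps: int = 64,
) -> Dict[str, float]:
    """Hypothetical soft update, applied only every `sync_steps` steps."""
    if step == 0 or step % sync_steps != 0:
        return reference
    # Assumption: ref <- alpha * policy + (1 - alpha) * ref.
    return {
        name: alpha * policy[name] + (1.0 - alpha) * reference[name]
        for name in reference
    }


policy = {"w": 1.0}
reference = {"w": 0.0}
print(maybe_sync_reference(63, policy, reference))  # {'w': 0.0}, no sync yet
print(maybe_sync_reference(64, policy, reference))  # {'w': 0.9}
```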

rpo_alpha#

Alpha parameter for Relative Preference Optimization. None disables RPO. Default: None

Type

float | None

tools#

Additional tools for training process

Type

list[dict | Callable] | None

Example

>>> config = DPOConfig(
...   beta=0.2, loss_type="ipo", max_length=1024, learning_rate=5e-6
... )
beta: float = 0.1#
dataset_num_proc: Optional[int] = None#
disable_dropout: bool = True#
force_use_ref_model: bool = False#
classmethod from_dict(data)#

Create an instance from a dictionary (deserialization).

classmethod from_json(json_str)#

Create an instance from a JSON string.

is_encoder_decoder: Optional[bool] = None#
label_pad_token_id: int = -100#
label_smoothing: float = 0.0#
learning_rate: float = 1e-06#
loss_type: Literal['sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down'] = 'sigmoid'#
max_completion_length: Optional[int] = None#
max_length: Optional[int] = 512#
max_prompt_length: Optional[int] = 256#
model_name: str = 'DPOTrainer'#
padding_value: Optional[int] = None#
precompute_ref_log_probs: bool = False#
ref_model_mixup_alpha: float = 0.9#
ref_model_sync_steps: int = 64#
reference_free: bool = False#
replace(**kwargs)#
rpo_alpha: Optional[float] = None#
sync_ref_model: bool = False#
to_dict()#

Convert the instance to a dictionary for JSON serialization.

to_json(**kwargs)#

Convert the instance to a JSON string.

tools: Optional[List[Union[dict, Callable]]] = None#
use_weighting: bool = False#
class easydel.__init__.DPOTrainer(arguments: DPOConfig, model: Union[EasyDeLBaseModule, EasyDeLState], reference_model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, processing_class: Optional[Any] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Any] = None, data_collator: Optional[Callable] = None)[source]#

Bases: Trainer

Trainer for Direct Preference Optimization (DPO).

This trainer handles the training, evaluation, and checkpointing of language models using the DPO algorithm. It supports sharding, gradient accumulation, mixed precision training, LoRA, and precomputed reference model log probabilities.

arguments: DPOConfig#
compute_reference_log_probs(state: EasyDeLState, padded_batch: Dict) → tuple[Any, Any][source]#

Computes log probabilities of the reference model for a single padded batch of a DPO-specific dataset.

Parameters
  • state (EasyDeLState) – The EasyDeLState object of the model (used if no reference model is provided).

  • padded_batch (tp.Dict) – The padded batch of data.

Returns

A tuple containing the log probabilities for the chosen and rejected responses.

Return type

tuple[tp.Any, tp.Any]
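
The per-response log probability that such a method returns is typically the sum of token log probabilities over the positions that carry a real label. A minimal sketch, assuming masked positions are marked with label_pad_token_id (-100 by default in DPOConfig); the function name and inputs are hypothetical.

```python
from typing import Sequence


def sequence_log_prob(
    token_log_probs: Sequence[float],
    labels: Sequence[int],
    label_pad_token_id: int = -100,
) -> float:
    """Sum token log-probabilities, skipping positions with the pad label."""
    return sum(
        log_prob
        for log_prob, label in zip(token_log_probs, labels)
        if label != label_pad_token_id
    )


# The first two positions belong to the prompt and are masked out.
print(sequence_log_prob([-0.5, -1.0, -0.25, -2.0], [-100, -100, 17, 42]))  # -2.25
```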

configure_dataloaders()[source]#

Returns the training dataloader, potentially with precomputed reference log probabilities.

If precompute_ref_log_probs is enabled, this method computes the reference model’s log probabilities for the chosen and rejected responses in the training dataset and adds them as columns to the dataset.

Returns

The training dataloader.

Return type

tensorflow.data.Dataset

configure_functions() → TrainerConfigureFunctionOutput[source]#

Configures and JIT-compiles the training and evaluation step functions.

This method sets up the necessary functions for training and evaluation, including:
  • Initialization of the model state.

  • Sharding of the model parameters and optimizer state.

  • JIT-compilation of the training and evaluation step functions.

Returns

An object containing the configured functions and other relevant information.

Return type

TrainerConfigureFunctionOutput

create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') → Callable[source]#

Creates a data collection function for batching.

For DPO training, this method simply returns the pre-configured data_collator.

Parameters
  • max_sequence_length (int) – The maximum sequence length (not used in this implementation).

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode (not used in this implementation). Defaults to “keep_end”.

Returns

The data collator function.

Return type

tp.Callable

on_step_end(state: EasyDeLState, metrics: Any, step: int) → Tuple[EasyDeLState, Any][source]#

Hook called at the end of each training step.

static process_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens)[source]#
static tokenize_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens)[source]#

Tokenize a row of the dataset.

Parameters
  • features (dict[str, str]) – Row of the dataset, should contain the keys “prompt”, “chosen”, and “rejected”.

  • processing_class (PreTrainedTokenizerBase) – Processing class used to process the data.

  • max_prompt_length (int or None) – Maximum length of the prompt sequence. If None, the prompt sequence is not truncated.

  • max_completion_length (int or None) – Maximum length of the completion sequences. If None, the completion sequences are not truncated.

  • add_special_tokens (bool) – Whether to add special tokens to the sequences. Typically used for encoder-decoder models. If True, the prompt sequence will have a bos token prepended and an eos token appended. In any case, the completion sequences will have an eos token appended.

Returns

Tokenized sequences with the keys “prompt_input_ids”, “chosen_input_ids”, and “rejected_input_ids”.

Return type

dict[str, list[int]]
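
The behaviour described above can be sketched with a toy whitespace tokenizer. ToyTokenizer, its token ids, and the choice of which end gets truncated (prompts keep their last tokens, completions their first) are assumptions made for illustration; the real method works with a PreTrainedTokenizerBase.

```python
from typing import Dict, List, Optional


class ToyTokenizer:
    """Hypothetical whitespace tokenizer standing in for a processing class."""

    bos_token_id = 1
    eos_token_id = 2

    def __init__(self) -> None:
        self.vocab: Dict[str, int] = {}

    def encode(self, text: str) -> List[int]:
        # Assign ids on first sight, starting after the special tokens.
        return [self.vocab.setdefault(word, len(self.vocab) + 3) for word in text.split()]


def tokenize_row(
    features: Dict[str, str],
    tokenizer: ToyTokenizer,
    max_prompt_length: Optional[int],
    max_completion_length: Optional[int],
    add_special_tokens: bool,
) -> Dict[str, List[int]]:
    prompt = tokenizer.encode(features["prompt"])
    chosen = tokenizer.encode(features["chosen"])
    rejected = tokenizer.encode(features["rejected"])
    if add_special_tokens:
        prompt = [tokenizer.bos_token_id] + prompt + [tokenizer.eos_token_id]
    # Completions always end with an eos token.
    chosen = chosen + [tokenizer.eos_token_id]
    rejected = rejected + [tokenizer.eos_token_id]
    # Assumption: prompts keep their tail, completions keep their head.
    if max_prompt_length is not None:
        prompt = prompt[-max_prompt_length:]
    if max_completion_length is not None:
        chosen = chosen[:max_completion_length]
        rejected = rejected[:max_completion_length]
    return {
        "prompt_input_ids": prompt,
        "chosen_input_ids": chosen,
        "rejected_input_ids": rejected,
    }


row = {"prompt": "what is two plus two", "chosen": "four", "rejected": "five hundred"}
print(tokenize_row(row, ToyTokenizer(), 3, 2, add_special_tokens=False))
# {'prompt_input_ids': [5, 6, 5], 'chosen_input_ids': [7, 2], 'rejected_input_ids': [8, 9]}
```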

class easydel.__init__.DbrxAttentionConfig(attn_pdrop: float = 0, clip_qkv: Optional[float] = 8, kv_n_heads: int = 1, rope_theta: float = 10000.0, **kwargs: Any)[source]#

Bases: EasyDeLBaseConfig

This is the configuration class to store the attention related configuration of a [DbrxModel].

Parameters
  • attn_pdrop (float, optional, defaults to 0.0) – The dropout probability applied to the attention output.

  • clip_qkv (float, optional, defaults to 8.0) – The clip value applied to the query, key, and value tensors.

  • kv_n_heads (int, optional, defaults to 1) – The number of attention heads for the key and value tensors.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value for the rotary position embedding.
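
As a small illustration of clip_qkv, the sketch below clamps values to the symmetric range [-clip_qkv, clip_qkv]. The helper is hypothetical and works on a plain list rather than on the query, key, and value tensors; treating None as "no clipping" is an assumption based on the Optional type in the signature.

```python
from typing import List, Optional


def clip_qkv_values(values: List[float], clip_qkv: Optional[float] = 8.0) -> List[float]:
    """Clamp each value to [-clip_qkv, clip_qkv]; None disables clipping."""
    if clip_qkv is None:
        return list(values)
    return [max(-clip_qkv, min(clip_qkv, value)) for value in values]


print(clip_qkv_values([-12.0, 0.5, 9.0]))  # [-8.0, 0.5, 8.0]
```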

classmethod from_pretrained(pretrained_model_name_or_path: str, **kwargs: Any) → PretrainedConfig[source]#

Instantiate a [PretrainedConfig] (or a derived class) from a pretrained model configuration.

Parameters
  • pretrained_model_name_or_path (str or os.PathLike) –

    This can be either:

    • a string, the model id of a pretrained model configuration hosted inside a model repo on huggingface.co.

    • a path to a directory containing a configuration file saved using the [~PretrainedConfig.save_pretrained] method, e.g., ./my_model_directory/.

    • a path or url to a saved configuration JSON file, e.g., ./my_model_directory/configuration.json.

  • cache_dir (str or os.PathLike, optional) – Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.

  • force_download (bool, optional, defaults to False) – Whether or not to force to (re-)download the configuration files and override the cached versions if they exist.

  • resume_download – Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.

  • proxies (Dict[str, str], optional) – A dictionary of proxy servers to use by protocol or endpoint, e.g., {‘http’: ‘foo.bar:3128’, ‘http://hostname’: ‘foo.bar:4012’}. The proxies are used on each request.

  • token (str or bool, optional) – The token to use as HTTP bearer authorization for remote files. If True, or not specified, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).

  • revision (str, optional, defaults to “main”) –

    The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.

    Tip: To test a pull request you made on the Hub, you can pass revision="refs/pr/<pr_number>".

  • return_unused_kwargs (bool, optional, defaults to False) –

    If False, then this function returns just the final configuration object.

    If True, then this function returns a tuple (config, unused_kwargs), where unused_kwargs is a dictionary of the key/value pairs whose keys are not configuration attributes, i.e., the part of kwargs that was not used to update config and is otherwise ignored.

  • subfolder (str, optional, defaults to “”) – In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here.

  • kwargs (Dict[str, tp.Any], optional) – The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not configuration attributes is controlled by the return_unused_kwargs keyword parameter.

Returns

The configuration object instantiated from this pretrained model.

Return type

[PretrainedConfig]

Examples:

>>> from transformers import BertConfig
>>> # We can't directly instantiate the base class *PretrainedConfig*, so the examples use a
>>> # derived class: BertConfig
>>> config = BertConfig.from_pretrained(
...   "google-bert/bert-base-uncased"
... )  # Download configuration from huggingface.co and cache.
>>> config = BertConfig.from_pretrained(
...   "./test/saved_model/"
... )  # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
>>> config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
>>> config = BertConfig.from_pretrained(
...   "google-bert/bert-base-uncased", output_attentions=True, foo=False
... )
>>> assert config.output_attentions == True
>>> config, unused_kwargs = BertConfig.from_pretrained(
...   "google-bert/bert-base-uncased",
...   output_attentions=True,
...   foo=False,
...   return_unused_kwargs=True,
... )
>>> assert config.output_attentions == True
>>> assert unused_kwargs == {"foo": False}
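
The split between applied overrides and unused keyword arguments can be mimicked without downloading anything. The function below is a hypothetical stand-in that operates on a plain dictionary; it only illustrates the return_unused_kwargs contract and is not how the library loads configurations.

```python
from typing import Any, Dict, Tuple


def apply_overrides(
    config: Dict[str, Any], **kwargs: Any
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Hypothetical helper: override known keys and collect the rest."""
    updated = dict(config)
    unused: Dict[str, Any] = {}
    for key, value in kwargs.items():
        if key in updated:
            updated[key] = value  # key is a configuration attribute
        else:
            unused[key] = value   # key is ignored by the configuration
    return updated, unused


config, unused_kwargs = apply_overrides(
    {"output_attentions": False}, output_attentions=True, foo=False
)
print(config)         # {'output_attentions': True}
print(unused_kwargs)  # {'foo': False}
```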


class easydel.__init__.DbrxConfig(d_model: int = 2048, n_heads: int = 16, n_layers: int = 24, max_seq_len: int = 2048, vocab_size: int = 32000, resid_pdrop: float = 0.0, emb_pdrop: float = 0.0, attn_config: Optional[DbrxAttentionConfig] = None, ffn_config: Optional[DbrxFFNConfig] = None, use_cache: bool = True, initializer_range: float = 0.02, output_router_logits: bool = False, router_aux_loss_coef: float = 0.05, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs: Any)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • d_model (int, optional, defaults to 2048) – Dimensionality of the encoder layers and the pooler layer.

  • n_heads (int, optional, defaults to 16) – Number of attention heads for each attention layer in the Transformer encoder.

  • n_layers (int, optional, defaults to 24) – Number of hidden layers in the Transformer encoder.

  • max_seq_len (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • vocab_size (int, optional, defaults to 32000) – Vocabulary size of the DBRX model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • resid_pdrop (float, optional, defaults to 0.0) – The dropout probability applied to a sub-layer's output before it is added to the residual stream.

  • emb_pdrop (float, optional, defaults to 0.0) – The dropout probability for the embedding layer.

  • attn_config ([DbrxAttentionConfig], optional) – The configuration of the attention layer.

  • ffn_config ([DbrxFFNConfig], optional) – The configuration of the feed forward layer.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • output_router_logits (bool, optional, defaults to False) – Whether or not to output the router logits.

  • router_aux_loss_coef (float, optional, defaults to 0.05) – The coefficient of the router auxiliary loss.

attribute_map: dict[str, str] = {'hidden_size': 'd_model', 'max_position_embeddings': 'max_seq_len', 'num_attention_heads': 'n_heads', 'num_hidden_layers': 'n_layers'}#
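
The attribute_map above lets generic names such as hidden_size resolve to DBRX-specific fields such as d_model. A minimal sketch of that aliasing idea follows; the AliasedConfig class is hypothetical and read-only, whereas the real configuration classes also handle assignment through the alias.

```python
class AliasedConfig:
    """Hypothetical config showing how an attribute map can alias names."""

    attribute_map = {
        "hidden_size": "d_model",
        "max_position_embeddings": "max_seq_len",
        "num_attention_heads": "n_heads",
        "num_hidden_layers": "n_layers",
    }

    def __init__(self, d_model=2048, n_heads=16, n_layers=24, max_seq_len=2048):
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len

    def __getattr__(self, name):
        # Called only when normal lookup fails, so real fields are unaffected.
        attribute_map = type(self).attribute_map
        if name in attribute_map:
            return getattr(self, attribute_map[name])
        raise AttributeError(name)


config = AliasedConfig(d_model=4096)
print(config.hidden_size)          # 4096
print(config.num_attention_heads)  # 16
```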
get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model parameters.

These rules define how parameters should be sharded across devices when using model parallelism.

Parameters
  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns

A tuple of partition rules for different parameter patterns.

Return type

Tuple

property granted_freq_max_position_embedding: int#

Returns the maximum position embedding size for frequency-based position embeddings.

Returns

The maximum position embedding size, falling back to max_seq_len if not explicitly set.

Return type

int

property granted_mask_max_position_embedding: int#

Returns the maximum position embedding size for mask-based position embeddings.

Returns

The maximum position embedding size, falling back to max_seq_len if not explicitly set.

Return type

int

model_type: str = 'dbrx'#
class easydel.__init__.DbrxFFNConfig(ffn_act_fn: Optional[dict] = None, ffn_hidden_size: int = 3584, moe_num_experts: int = 4, moe_top_k: int = 1, moe_jitter_eps: Optional[float] = None, moe_loss_weight: float = 0.01, moe_normalize_expert_weights: Optional[float] = 1, uniform_expert_assignment: bool = False, **kwargs: Any)[source]#

Bases: EasyDeLBaseConfig

This is the configuration class to store the feed forward related configuration of a [DbrxModel].

Parameters
  • ffn_act_fn (dict, optional) – The activation function configuration for the feed-forward network.

  • ffn_hidden_size (int, optional, defaults to 3584) – The hidden size of the feed-forward network.

  • moe_num_experts (int, optional, defaults to 4) – The number of experts in the Mixture-of-Experts (MoE) layer.

  • moe_top_k (int, optional, defaults to 1) – The number of top experts to use in the MoE layer.

  • moe_jitter_eps (float, optional) – The jitter epsilon value for the MoE layer.

  • moe_loss_weight (float, optional, defaults to 0.01) – The loss weight for the MoE auxiliary loss.

  • moe_normalize_expert_weights (float, optional, defaults to 1.0) – The normalization factor for the expert weights in the MoE layer.

  • uniform_expert_assignment (bool, optional, defaults to False) – Whether to use uniform expert assignment in the MoE layer.
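
A rough pure-Python sketch of how moe_top_k and moe_normalize_expert_weights might act on one token's router logits: softmax over experts, keep the top k, then rescale the kept weights. Treating moe_normalize_expert_weights as the order p of a p-norm is an assumption here, and the function name is hypothetical.

```python
import math
from typing import List, Optional, Tuple


def route_token(
    router_logits: List[float],
    moe_top_k: int = 1,
    moe_normalize_expert_weights: Optional[float] = 1.0,
) -> Tuple[List[int], List[float]]:
    """Return (expert indices, expert weights) for a single token."""
    # Softmax over experts, shifted by the max for numerical stability.
    peak = max(router_logits)
    exps = [math.exp(logit - peak) for logit in router_logits]
    total = sum(exps)
    probs = [value / total for value in exps]
    # Keep the k most probable experts.
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:moe_top_k]
    weights = [probs[i] for i in top]
    # Assumption: the config value is the order p of the norm used to rescale.
    if moe_normalize_expert_weights is not None:
        p = moe_normalize_expert_weights
        norm = sum(abs(w) ** p for w in weights) ** (1.0 / p)
        weights = [w / norm for w in weights]
    return top, weights


experts, weights = route_token([2.0, 1.0, 0.0, -1.0], moe_top_k=2)
print(experts)  # [0, 1]
```

With p = 1 the kept weights sum to one, so the token's output is a convex combination of the selected experts.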

classmethod from_pretrained(pretrained_model_name_or_path: str, **kwargs: Any) → EasyDeLBaseConfig[source]#

Instantiate a [PretrainedConfig] (or a derived class) from a pretrained model configuration.

Parameters
  • pretrained_model_name_or_path (str or os.PathLike) –

    This can be either:

    • a string, the model id of a pretrained model configuration hosted inside a model repo on huggingface.co.

    • a path to a directory containing a configuration file saved using the [~PretrainedConfig.save_pretrained] method, e.g., ./my_model_directory/.

    • a path or url to a saved configuration JSON file, e.g., ./my_model_directory/configuration.json.

  • cache_dir (str or os.PathLike, optional) – Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.

  • force_download (bool, optional, defaults to False) – Whether or not to force to (re-)download the configuration files and override the cached versions if they exist.

  • resume_download – Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.

  • proxies (Dict[str, str], optional) – A dictionary of proxy servers to use by protocol or endpoint, e.g., {‘http’: ‘foo.bar:3128’, ‘http://hostname’: ‘foo.bar:4012’}. The proxies are used on each request.

  • token (str or bool, optional) – The token to use as HTTP bearer authorization for remote files. If True, or not specified, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).

  • revision (str, optional, defaults to “main”) –

    The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.

    Tip: To test a pull request you made on the Hub, you can pass revision="refs/pr/<pr_number>".

  • return_unused_kwargs (bool, optional, defaults to False) –

    If False, then this function returns just the final configuration object.

    If True, then this function returns a tuple (config, unused_kwargs), where unused_kwargs is a dictionary of the key/value pairs whose keys are not configuration attributes, i.e., the part of kwargs that was not used to update config and is otherwise ignored.

  • subfolder (str, optional, defaults to “”) – In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here.

  • kwargs (Dict[str, tp.Any], optional) – The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not configuration attributes is controlled by the return_unused_kwargs keyword parameter.

Returns

The configuration object instantiated from this pretrained model.

Return type

[PretrainedConfig]

Examples:

>>> from transformers import BertConfig
>>> # We can't directly instantiate the base class *PretrainedConfig*, so the examples use a
>>> # derived class: BertConfig
>>> config = BertConfig.from_pretrained(
...   "google-bert/bert-base-uncased"
... )  # Download configuration from huggingface.co and cache.
>>> config = BertConfig.from_pretrained(
...   "./test/saved_model/"
... )  # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
>>> config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
>>> config = BertConfig.from_pretrained(
...   "google-bert/bert-base-uncased", output_attentions=True, foo=False
... )
>>> assert config.output_attentions == True
>>> config, unused_kwargs = BertConfig.from_pretrained(
...   "google-bert/bert-base-uncased",
...   output_attentions=True,
...   foo=False,
...   return_unused_kwargs=True,
... )
>>> assert config.output_attentions == True
>>> assert unused_kwargs == {"foo": False}


class easydel.__init__.DbrxForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.DbrxForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.DbrxModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Base DBRX Model outputting raw hidden-states.

This model is a Transformer-based model with a mixture of experts (MoE) architecture, implementing the DBRX architecture as described in the original paper.

The model uses specialized attention modules and a router-based MoE FFN layer.

property frequencies#

Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_frequencies() and caches the result.

Returns

The frequency components, potentially cached.

Return type

jnp.ndarray
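
For orientation, the standard rotary position embedding (RoPE) inverse frequencies can be computed as below. What get_basic_frequencies() actually returns (for example precomputed sine and cosine tables rather than raw frequencies) is not stated here, so this is only a sketch of the usual formula with a hypothetical function name.

```python
from typing import List


def rope_inverse_frequencies(head_dim: int, rope_theta: float = 10000.0) -> List[float]:
    """Standard RoPE formula: 1 / theta^(2i / head_dim) for each dimension pair."""
    return [1.0 / (rope_theta ** (i / head_dim)) for i in range(0, head_dim, 2)]


freqs = rope_inverse_frequencies(head_dim=128)
print(len(freqs))  # 64, one frequency per pair of dimensions
print(freqs[0])    # 1.0, the fastest-rotating pair
```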

class easydel.__init__.DeepseekV2Config(vocab_size=102400, hidden_size=4096, intermediate_size=11008, moe_intermediate_size=1407, num_hidden_layers=30, num_attention_heads=32, num_key_value_heads=32, n_shared_experts=None, n_routed_experts=None, ep_size=1, routed_scaling_factor=1.0, kv_lora_rank=512, q_lora_rank=1536, qk_rope_head_dim=64, v_head_dim=128, qk_nope_head_dim=128, topk_method='gready', n_group=None, topk_group=None, num_experts_per_tok=None, moe_layer_freq=1, first_k_dense_replace=0, norm_topk_prob=False, scoring_func='softmax', aux_loss_alpha=0.001, seq_aux=True, hidden_act='silu', max_position_embeddings=2048, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, pad_token_id=None, bos_token_id=100000, eos_token_id=100001, pretraining_tp=1, tie_word_embeddings=False, rope_theta=10000.0, attention_bias=False, attention_dropout=0.0, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, rope_scaling: Dict[str, Union[str, float]] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 102400) – Vocabulary size of the DeepseekV2 model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 11008) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • moe_intermediate_size (int, optional, defaults to 1407) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the MoE layer.

  • num_hidden_layers (int, optional, defaults to 30) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 32) – Number of key and value heads for each attention layer in the Transformer encoder.

  • n_shared_experts (int, optional) – Number of shared experts.

  • n_routed_experts (int, optional) – Number of routed experts.

  • ep_size (int, optional, defaults to 1) – Expert parallel size.

  • routed_scaling_factor (float, optional, defaults to 1.0) – Routed scaling factor.

  • kv_lora_rank (int, optional, defaults to 512) – KV LoRA rank.

  • q_lora_rank (int, optional, defaults to 1536) – Q LoRA rank.

  • qk_rope_head_dim (int, optional, defaults to 64) – QK rope head dimension.

  • v_head_dim (int, optional, defaults to 128) – V head dimension.

  • qk_nope_head_dim (int, optional, defaults to 128) – QK nope head dimension.

  • topk_method (str, optional, defaults to “gready”) – Top-k method.

  • n_group (int, optional) – Number of groups.

  • topk_group (int, optional) – Top-k group.

  • num_experts_per_tok (int, optional) – Number of experts per token.

  • moe_layer_freq (int, optional, defaults to 1) – MoE layer frequency.

  • first_k_dense_replace (int, optional, defaults to 0) – First k dense replace.

  • norm_topk_prob (bool, optional, defaults to False) – Whether to normalize top-k probabilities.

  • scoring_func (str, optional, defaults to “softmax”) – Scoring function.

  • aux_loss_alpha (float, optional, defaults to 0.001) – Auxiliary loss alpha.

  • seq_aux (bool, optional, defaults to True) – Whether to use sequence auxiliary loss.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional) – The index of the padding token in the vocabulary.

  • bos_token_id (int, optional, defaults to 100000) – The index of the beginning of sequence token in the vocabulary.

  • eos_token_id (int, optional, defaults to 100001) – The index of the end of sequence token in the vocabulary.

  • pretraining_tp (int, optional, defaults to 1) – Pretraining TP.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • attention_bias (bool, optional, defaults to False) – Whether to use attention bias.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • use_scan_mlp (bool, optional, defaults to False) – Whether to use scan for MLP.

  • scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size for scan MLP.

  • bits (int, optional) – The number of bits to quantize the model to.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The rope scaling configuration.
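
The routing parameters above (scoring_func, num_experts_per_tok, norm_topk_prob, routed_scaling_factor) can be illustrated with a minimal pure-Python sketch of greedy top-k routing for a single token. The function name route_token and the way normalization and scaling are combined are assumptions made for illustration; the actual EasyDeL implementation may differ.

```python
import math


def route_token(logits, num_experts_per_tok, norm_topk_prob=False, routed_scaling_factor=1.0):
    """Hypothetical helper: pick the top-k experts for one token from router logits."""
    # scoring_func="softmax": convert logits to probabilities (shifted for stability).
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    scores = [e / total for e in exps]

    # Greedy top-k: keep the indices of the k highest-scoring experts.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    top = order[:num_experts_per_tok]
    weights = [scores[i] for i in top]

    # Optionally renormalize the selected weights so they sum to 1.
    if norm_topk_prob:
        denom = sum(weights)
        weights = [w / denom for w in weights]

    # Assumption: the scaling factor is applied after normalization.
    weights = [w * routed_scaling_factor for w in weights]
    return top, weights


experts, weights = route_token([0.1, 2.0, -1.0, 1.5], num_experts_per_tok=2, norm_topk_prob=True)
print(experts, weights)
```

With norm_topk_prob=False the selected weights keep their raw softmax values, so they generally sum to less than 1.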

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()[source]#

Returns a tuple of parameter names for which weight decay should be excluded.

Returns

An empty tuple, indicating no weight decay exclusions.

Return type

tuple

property granted_freq_max_position_embedding: int#

Returns the maximum position embedding size for frequency-based position embeddings.

Returns

The maximum position embedding size, falling back to max_position_embeddings if not explicitly set.

Return type

int

property granted_mask_max_position_embedding: int#

Returns the maximum position embedding size for mask-based position embeddings.

Returns

The maximum position embedding size, falling back to max_position_embeddings if not explicitly set.

Return type

int

model_type: str = 'deepseek_v2'#
static rng_keys()[source]#

Returns the names of the random number generator keys used by the model.

Returns

A tuple containing “params” and “dropout” as the RNG keys.

Return type

tuple

class easydel.__init__.DeepseekV2ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

DeepseekV2 model with a language modeling head for causal language modeling tasks.

This model extends the base DeepseekV2Model by adding a linear language modeling head on top of the transformer model. It’s designed for generative tasks and can be used for text generation.

class easydel.__init__.DeepseekV2Model(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

property frequencies#

Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_frequencies() and caches the result.

Returns

The frequency components, potentially cached.

Return type

jnp.ndarray

class easydel.__init__.DeepseekV3Config(vocab_size=129280, hidden_size=7168, intermediate_size=18432, moe_intermediate_size=2048, num_hidden_layers=61, num_nextn_predict_layers=1, num_attention_heads=128, num_key_value_heads=128, n_shared_experts=1, n_routed_experts=256, ep_size=1, routed_scaling_factor=2.5, kv_lora_rank=512, q_lora_rank=1536, qk_rope_head_dim=64, v_head_dim=128, qk_nope_head_dim=128, topk_method='noaux_tc', n_group=8, topk_group=4, num_experts_per_tok=8, moe_layer_freq=1, first_k_dense_replace=3, norm_topk_prob=True, scoring_func='sigmoid', aux_loss_alpha=0.001, seq_aux=True, hidden_act='silu', max_position_embeddings=4096, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, pad_token_id=None, bos_token_id=0, eos_token_id=1, pretraining_tp=1, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, attention_bias=False, attention_dropout=0.0, **kwargs)[source]#

Bases: EasyDeLBaseConfig

This is the configuration class to store the configuration of a [DeepseekV3Model]. It is used to instantiate a DeepSeek model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration similar to that of DeepSeek-V3. Configuration objects inherit from [PretrainedConfig] and can be used to control the model outputs. Read the documentation from [PretrainedConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 129280) – Vocabulary size of the DeepSeek model. Defines the number of different tokens that can be represented by the input_ids passed when calling [DeepseekV3Model].

  • hidden_size (int, optional, defaults to 7168) – Dimension of the hidden representations.

  • intermediate_size (int, optional, defaults to 18432) – Dimension of the MLP representations.

  • moe_intermediate_size (int, optional, defaults to 2048) – Dimension of the MoE representations.

  • num_hidden_layers (int, optional, defaults to 61) – Number of hidden layers in the Transformer decoder.

  • num_nextn_predict_layers (int, optional, defaults to 1) – Number of next-n prediction layers in the DeepSeekV3 model.

  • num_attention_heads (int, optional, defaults to 128) – Number of attention heads for each attention layer in the Transformer decoder.

  • n_shared_experts (int, optional, defaults to 1) – Number of shared experts; None means a dense model.

  • n_routed_experts (int, optional, defaults to 256) – Number of routed experts; None means a dense model.

  • routed_scaling_factor (float, optional, defaults to 2.5) – Scaling factor for routed experts.

  • topk_method (str, optional, defaults to “noaux_tc”) – Top-k method used in the routed gate.

  • n_group (int, optional, defaults to 8) – Number of groups for routed experts.

  • topk_group (int, optional, defaults to 4) – Number of selected groups for each token (the experts selected for a token are restricted to topk_group groups).

  • num_experts_per_tok (int, optional, defaults to 8) – Number of selected experts; None means a dense model.

  • moe_layer_freq (int, optional, defaults to 1) – The frequency of the MoE layer: one expert layer for every moe_layer_freq - 1 dense layers.

  • first_k_dense_replace (int, optional, defaults to 3) – Number of dense layers at the shallow end of the stack (embed -> dense -> dense -> … -> dense -> moe -> moe … -> lm_head); the first k layers are dense.

  • norm_topk_prob (bool, optional, defaults to True) – Whether to normalize the weights of the routed experts.

  • scoring_func (str, optional, defaults to “sigmoid”) – Method of computing expert weights.

  • aux_loss_alpha (float, optional, defaults to 0.001) – Auxiliary loss weight coefficient.

  • seq_aux (bool, optional, defaults to True) – Whether to compute the auxiliary loss for each individual sample.

  • num_key_value_heads (int, optional) – The number of key/value heads used to implement Grouped Query Attention. If num_key_value_heads=num_attention_heads, the model uses Multi Head Attention (MHA); if num_key_value_heads=1, the model uses Multi Query Attention (MQA); otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group's key and value head should be constructed by mean-pooling all the original heads within that group. For more details, check out [this paper](https://arxiv.org/pdf/2305.13245.pdf). If not specified, defaults to num_attention_heads.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) in the decoder.

  • max_position_embeddings (int, optional, defaults to 4096) – The maximum sequence length that this model might ever be used with.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-06) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional) – Padding token id.

  • bos_token_id (int, optional, defaults to 0) – Beginning of stream token id.

  • eos_token_id (int, optional, defaults to 1) – End of stream token id.

  • pretraining_tp (int, optional, defaults to 1) – Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is necessary to ensure exact reproducibility of the pretraining results. Please refer to [this issue](pytorch/pytorch#76232).

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input and output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The base period of the RoPE embeddings.

  • rope_scaling (Dict, optional) – Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is {“type”: strategy name, “factor”: scaling factor}. When using this flag, don’t update max_position_embeddings to the expected new maximum.

  • attention_bias (bool, optional, defaults to False) – Whether to use a bias in the query, key, value and output projection layers during self-attention.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

Example:

>>> from easydel import DeepseekV3Config
>>> # Initializing a Deepseek-V3 style configuration
>>> configuration = DeepseekV3Config()
>>> # Accessing a configuration attribute
>>> hidden_size = configuration.hidden_size
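
The layer-placement parameters above (n_routed_experts, first_k_dense_replace, moe_layer_freq) can be read as a rule for which decoder layers use an MoE feed-forward block. The sketch below is a pure-Python illustration of that rule as it is commonly implemented for DeepSeek-style models; the helper name moe_layer_indices is hypothetical, and the exact condition used by EasyDeL is an assumption.

```python
def moe_layer_indices(num_hidden_layers, first_k_dense_replace, moe_layer_freq, n_routed_experts):
    """Hypothetical helper: return the indices of layers that use an MoE block."""
    # None for n_routed_experts means a dense model: no MoE layers at all.
    if n_routed_experts is None:
        return []
    return [
        idx
        for idx in range(num_hidden_layers)
        # The first k layers stay dense; after that, every moe_layer_freq-th layer is MoE.
        if idx >= first_k_dense_replace and idx % moe_layer_freq == 0
    ]


# Values taken from the DeepseekV3Config signature defaults.
layers = moe_layer_indices(
    num_hidden_layers=61, first_k_dense_replace=3, moe_layer_freq=1, n_routed_experts=256
)
print(len(layers), layers[0], layers[-1])
```

Under this rule the signature defaults give three dense layers followed by MoE layers for the rest of the stack.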

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for model parameters.

These rules define how parameters should be sharded across devices when using model parallelism.

Parameters
  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns

A tuple of partition rules for different parameter patterns.

Return type

Tuple

keys_to_ignore_at_inference = ['past_key_values']#
model_type: str = 'deepseek_v3'#
class easydel.__init__.DeepseekV3ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

DeepseekV3 model with a language modeling head for causal language modeling tasks.

This model extends the base DeepseekV3Model by adding a linear language modeling head on top of the transformer model. It incorporates Mixture of Experts (MoE) architecture and is designed for generative tasks and text generation.

class easydel.__init__.DeepseekV3Model(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

property frequencies#

Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_frequencies() and caches the result.

Returns

The frequency components, potentially cached.

Return type

jnp.ndarray

class easydel.__init__.EasyDeLBackends(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: str, Enum

Enumeration of JAX backend types supported by EasyDeL.

Specifies the target hardware device type for JAX computations.

CPU#

Use the CPU backend.

GPU#

Use the GPU backend.

TPU#

Use the TPU backend.

CPU = 'cpu'#
GPU = 'gpu'#
TPU = 'tpu'#
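
Because the enumeration derives from both str and Enum, its members compare equal to plain strings and can be looked up by value. The sketch below uses a hypothetical stand-in class named Backend with the same three values, so it runs without EasyDeL installed.

```python
from enum import Enum


class Backend(str, Enum):
    """Hypothetical stand-in mirroring the documented EasyDeLBackends values."""

    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"


# Lookup by value returns the matching member.
selected = Backend("tpu")
print(selected is Backend.TPU)

# Members of a str-based Enum compare equal to their string value.
print(Backend.GPU == "gpu")
```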
class easydel.__init__.EasyDeLBaseConfig(axis_dims: ~typing.Sequence[int] = (1, -1, 1, 1), dcn_axis_dims: ~typing.Optional[~typing.Sequence[int]] = None, axis_names: ~typing.Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), attn_mechanism: ~typing.Literal['vanilla', 'flash_attn2', 'splash', 'ring', 'cudnn', 'blockwise', 'sdpa'] = 'vanilla', blocksize_k: int = 128, blocksize_q: int = 128, blocksize_b: int = 1, partition_axis: ~eformer.escale.partition.constraints.PartitionAxis = PartitionAxis(   data_parallel_axis : dp   fully_sharded_data_parallel_axis : fsdp   tensor_parallel_axis : tp   sequence_parallel_axis : sp   expert_parallel_axis : ep   batch_axis : ('fsdp', 'dp')   sequence_axis : sp   query_sequence_axis : sp   head_axis : tp   key_sequence_axis : sp   hidden_state_axis : tp   mlp_intermediate_axis : tp   vocab_axis : tp   expert_axis : ep   expert_gate_axis : None   attention_dim_axis : None   bias_head_sequence_axis : None   bias_key_sequence_axis : None   generation_batch_axis : None   generation_query_sequence_axis : None   generation_head_axis : tp   generation_key_sequence_axis : sp   generation_attention_dim_axis : None ), shard_attention_computation: bool = True, use_sharded_kv_caching: bool = False, use_sharding_constraint: bool = False, backend: ~typing.Optional[~easydel.__init__.infra.etils.EasyDeLBackends] = None, platform: ~typing.Optional[~easydel.__init__.infra.etils.EasyDeLPlatforms] = None, easy_method: ~typing.Literal['train', 'serve', 'convert'] = 'train', bits: ~typing.Optional[int] = None, scan_ring_attention: bool = True, scan_attention_layers: bool = False, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, sequence_axis_name: str = 'sp', gradient_checkpointing: ~easydel.__init__.infra.etils.EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, precompute_masks: bool = True, kv_cache_quantization_method: ~easydel.__init__.infra.etils.EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.NONE, 
kv_cache_quantization_blocksize: int = 64, quantization_method: ~easydel.__init__.infra.etils.EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.NONE, quantization_pattern: str = '.*', quantization_blocksize: int = 64, kv_cache_sharding_sequence_axis_name: ~typing.Union[str, ~typing.Tuple[str, ...]] = 'sp', flash_attention_backward_pass_impl: ~typing.Literal['triton', 'xla'] = 'triton', attn_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, attn_softmax_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, fcm_max_ratio: float = 0.0, fcm_min_ratio: float = 0.0, hardware_abstraction: bool = False, pallas_m_block_size: int = 128, pallas_k_block_size: int = 128, pallas_n_block_size: int = 128, **kwargs)[source]#

Bases: PretrainedConfig

Initialize the configuration for EasyDeL.

Parameters
  • axis_dims (tp.Sequence[int]) – Dimensions of the axes. Default is (1, -1, 1, 1).

  • axis_names (tp.Sequence[str]) – Names of the axes. Default is (“dp”, “fsdp”, “tp”, “sp”).

  • attn_mechanism (AVAILABLE_ATTENTION_MECHANISMS) – Attention mechanism to use. Default is DEFAULT_ATTENTION_MECHANISM (“vanilla”).

  • blocksize_k (int) – Block size for key. Default is 128.

  • blocksize_q (int) – Block size for query. Default is 128.

  • blocksize_b (int) – Block size for batch. Default is 1.

  • partition_axis (PartitionAxis) – Partition axis configuration. Default is PartitionAxis().

  • shard_attention_computation (bool) – Whether to shard attention computation. Default is True.

  • use_sharded_kv_caching (bool) – Whether to use sharded key-value caching. Default is False.

  • use_sharding_constraint (bool) – Whether to use sharding constraint. Default is False.

  • backend (tp.Optional[EasyDeLBackends]) – Backend to use. Default is None.

  • platform (tp.Optional[EasyDeLPlatforms]) – Platform to use. Default is None.

  • easy_method (tp.Literal[“train”, “serve”, “convert”]) – Method to use. Default is “train”.

  • bits (tp.Optional[int]) – Number of bits for quantization. Default is None.

  • scan_ring_attention (bool) – Whether to scan ring attention. Default is True.

  • scan_attention_layers (bool) – Whether to scan attention layers. Default is False.

  • use_scan_mlp (bool) – Whether to use scan MLP. Default is False.

  • scan_mlp_chunk_size (int) – Chunk size for scan MLP. Default is 1024.

  • sequence_axis_name (str) – Name of the attention axis. Default is “sp”.

  • gradient_checkpointing (EasyDeLGradientCheckPointers) – Gradient checkpointing method. Default is EasyDeLGradientCheckPointers.NONE.

  • kv_cache_quantization_method (EasyDeLQuantizationMethods) – Key-value cache quantization method. Default is EasyDeLQuantizationMethods.NONE.

  • kv_cache_quantization_blocksize (int) – Block size for key-value cache quantization. Default is 64.

  • quantization_method (EasyDeLQuantizationMethods) – Quantization method. Default is EasyDeLQuantizationMethods.NONE.

  • quantization_pattern (str) – Pattern for quantization. Default is “.*”.

  • quantization_blocksize (int) – Block size for quantization. Default is 64.

  • kv_cache_sharding_sequence_axis_name (tp.Union[str, tp.Tuple[str, …]]) – Name of the key-value cache sharding sequence axis. Default is “sp”.

  • flash_attention_backward_pass_impl (tp.Literal[“triton”, “xla”]) – Implementation for flash attention backward pass. Default is “triton”.

  • attn_dtype (jnp.dtype) – Data type for attention. Default is jnp.float32.

  • attn_softmax_dtype (jnp.dtype) – Data type for softmax ops in attention. Default is jnp.float32.

  • fcm_max_ratio (float) – Maximum ratio for FCM. Default is 0.0.

  • fcm_min_ratio (float) – Minimum ratio for FCM. Default is 0.0.

  • hardware_abstraction (bool) – Whether to use hardware abstraction. Default is DEFAULT_HARDWARE_ABSTRACTION (False).

  • pallas_m_block_size (int) – Block size for Pallas M. Default is DEFAULT_PALLAS_M_BLOCK_SIZE (128).

  • pallas_k_block_size (int) – Block size for Pallas K. Default is DEFAULT_PALLAS_K_BLOCK_SIZE (128).

  • pallas_n_block_size (int) – Block size for Pallas N. Default is DEFAULT_PALLAS_N_BLOCK_SIZE (128).

  • **kwargs – Additional keyword arguments.

Raises

Warning – If kv_cache_quantization_method is not NONE and use_sharded_kv_caching is True.
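
The documented condition can be sketched as a small validation step. This is a hypothetical illustration, not the EasyDeL code: the helper name check_kv_cache_options, the warning message, and the use of the strings "none" and "8bit" in place of EasyDeLQuantizationMethods members are all assumptions.

```python
import warnings


def check_kv_cache_options(kv_cache_quantization_method, use_sharded_kv_caching):
    """Hypothetical check: warn when both options are enabled at once.

    Returns True if the combination is fine, False if a warning was emitted.
    """
    # "none" is a stand-in for EasyDeLQuantizationMethods.NONE.
    if kv_cache_quantization_method != "none" and use_sharded_kv_caching:
        warnings.warn(
            "kv_cache_quantization_method is set while use_sharded_kv_caching is True",
            UserWarning,
        )
        return False
    return True


# Quantization disabled, so sharded KV caching raises no warning.
print(check_kv_cache_options("none", use_sharded_kv_caching=True))
```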

add_basic_configurations(axis_dims: Sequence[int] = Ellipsis, dcn_axis_dims: Optional[Sequence[int]] = Ellipsis, axis_names: Sequence[str] = Ellipsis, attn_mechanism: Literal['vanilla', 'flash_attn2', 'splash', 'ring', 'cudnn', 'blockwise', 'sdpa'] = Ellipsis, blocksize_k: int = Ellipsis, blocksize_q: int = Ellipsis, blocksize_b: int = Ellipsis, partition_axis: PartitionAxis = Ellipsis, shard_attention_computation: bool = Ellipsis, use_sharded_kv_caching: bool = Ellipsis, backend: Optional[EasyDeLBackends] = Ellipsis, platform: Optional[EasyDeLPlatforms] = Ellipsis, easy_method: Literal['train', 'serve', 'convert'] = Ellipsis, bits: Optional[int] = Ellipsis, scan_ring_attention: bool = Ellipsis, scan_attention_layers: bool = Ellipsis, use_sharding_constraint: bool = Ellipsis, use_scan_mlp: bool = Ellipsis, scan_mlp_chunk_size: int = Ellipsis, sequence_axis_name: str = Ellipsis, gradient_checkpointing: EasyDeLGradientCheckPointers = Ellipsis, precompute_masks: bool = Ellipsis, kv_cache_quantization_method: EasyDeLQuantizationMethods = Ellipsis, kv_cache_quantization_blocksize: int = Ellipsis, quantization_method: EasyDeLQuantizationMethods = Ellipsis, quantization_blocksize: int = Ellipsis, quantization_pattern: str = Ellipsis, kv_cache_sharding_sequence_axis_name: Union[str, Tuple[str, ...]] = Ellipsis, flash_attention_backward_pass_impl: Literal['triton', 'xla'] = Ellipsis, attn_dtype: dtype = Ellipsis, attn_softmax_dtype: dtype = Ellipsis, hardware_abstraction: bool = Ellipsis, pallas_m_block_size: int = Ellipsis, pallas_k_block_size: int = Ellipsis, pallas_n_block_size: int = Ellipsis, **kwargs)[source]#

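Sets the basic EasyDeL configuration attributes (sharding, attention, quantization, and kernel options) on this configuration object. Every argument defaults to Ellipsis in the signature; the defaults listed below are the values documented for arguments that are not supplied.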

Parameters
  • axis_dims (tp.Sequence[int], optional) – Specify the number of dimensions for each axis. Defaults to (1, -1, 1, 1).

  • axis_names (tp.Sequence[str], optional) – Set the names of the axes. Defaults to (“dp”, “fsdp”, “tp”, “sp”).

  • attn_mechanism (AVAILABLE_ATTENTION_MECHANISMS, optional) – attention mechanism to use. Defaults to DEFAULT_ATTENTION_MECHANISM.

  • blocksize_k (int, optional) – block size of key_states. Defaults to 128.

  • blocksize_q (int, optional) – block size of query_states. Defaults to 128.

  • blocksize_b (int, optional) – block size of bias. Defaults to 1.

  • partition_axis (PartitionAxis, optional) – PartitionAxis configuration used for partitioning arrays in EasyDeL. Defaults to PartitionAxis().

  • shard_attention_computation (bool, optional) – Whether to use shard_map for attention. Defaults to True.

  • use_sharded_kv_caching (bool, optional) – Whether to use shard_map and sharding for key and value. Defaults to False.

  • backend (tp.Optional[EasyDeLBackends], optional) – Specify the backend to use. Defaults to None.

  • platform (tp.Optional[EasyDeLPlatforms], optional) – Specify the platform to use. Defaults to None.

  • easy_method (tp.Literal["train", "serve", "convert"], optional) – EasyDeL method that quantization is applied for. Defaults to “train”.

  • bits (tp.Optional[int], optional) – Model bits for quantization. Defaults to None.

  • scan_ring_attention (bool, optional) – Whether to use scan for ring attention. Defaults to True.

  • scan_attention_layers (bool, optional) – Whether to use scan for attention layers. Defaults to False.

  • use_sharding_constraint (bool, optional) – whether to use sharding constraint for the arrays. Defaults to False.

  • use_scan_mlp (bool, optional) – Determine whether to use scan_mlp or not. Defaults to False.

  • scan_mlp_chunk_size (int, optional) – Size of chunks in scan MLP. Defaults to 1024.

  • sequence_axis_name (str, optional) – Name of the attention axis. Defaults to “sp”.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing method for the created or loaded module (applied to MLP and attention layers most of the time).

  • kv_cache_quantization_method (EasyDeLQuantizationMethods, optional) – key and value quantization type. Defaults to EasyDeLQuantizationMethods.NONE.

  • kv_cache_quantization_blocksize (int, optional) – size of kv cache quantization. Defaults to 64.

  • quantization_method (EasyDeLQuantizationMethods, optional) – linear modules quantization type. Defaults to EasyDeLQuantizationMethods.NONE.

  • quantization_blocksize (int, optional) – size of linear quantization. Defaults to 64.

  • quantization_pattern (str) – Regular-expression pattern used to select the layers to quantize.

  • kv_cache_sharding_sequence_axis_name (tp.Union[str, tp.Tuple[str, ...]], optional) – axis name to target for sharding sequences. Defaults to “sp”.

  • flash_attention_backward_pass_impl (tp.Literal["triton", "xla"], optional) – Specify the backward pass kernel for flash attention. Defaults to “triton”.

  • attn_dtype (jnp.dtype, optional) – Data type for attention computations. Defaults to device half.

  • attn_softmax_dtype (jnp.dtype, optional) – Data type for softmax in attention op computations. Defaults to jnp.float32.

  • fcm_max_ratio (float, optional) – Maximum mask ratio for FCM (forgetful causal masking). Defaults to 0.0.

  • fcm_min_ratio (float, optional) – Minimum mask ratio for FCM (forgetful causal masking). Defaults to 0.0.

  • hardware_abstraction (bool, optional) – Whether to switch to custom Pallas kernels instead of JAX. Defaults to DEFAULT_HARDWARE_ABSTRACTION.

  • pallas_m_block_size (int, optional) – Block size of the m dimension in the Pallas matmul kernel A(mk)@B(kn)=C(mn). Defaults to DEFAULT_PALLAS_M_BLOCK_SIZE.

  • pallas_k_block_size (int, optional) – Block size of the k dimension in the Pallas matmul kernel A(mk)@B(kn)=C(mn). Defaults to DEFAULT_PALLAS_K_BLOCK_SIZE.

  • pallas_n_block_size (int, optional) – Block size of the n dimension in the Pallas matmul kernel A(mk)@B(kn)=C(mn). Defaults to DEFAULT_PALLAS_N_BLOCK_SIZE.
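
As an illustration of how a quantization_pattern such as the default “.*” might select layers, the sketch below filters a list of layer names with Python's re module. The helper select_layers, the example layer names, and the choice of re.fullmatch are assumptions; the source does not state which matching function EasyDeL uses.

```python
import re


def select_layers(layer_names, quantization_pattern=".*"):
    """Hypothetical helper: keep the layer names matched by the pattern."""
    pattern = re.compile(quantization_pattern)
    # Assumption: the whole name must match the pattern.
    return [name for name in layer_names if pattern.fullmatch(name)]


# Hypothetical layer names, for illustration only.
names = [
    "model/layers/0/self_attn/q_proj",
    "model/layers/0/mlp/gate_proj",
    "lm_head",
]

print(select_layers(names))               # default ".*" selects every layer
print(select_layers(names, r".*mlp.*"))   # restrict quantization to MLP layers
```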

attach_custom_arguments(**kwargs)[source]#
static create_mesh(axis_dims: Sequence[int] = (1, -1, 1, 1), axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), dcn_axis_dims: Optional[Sequence[int]] = None, process_is_granule: bool = False, should_sort_granules_by_key: bool = True, allow_split_physical_axes: bool = True, backend: Optional[str] = None)[source]#

The create_mesh function creates a mesh object that can be used to shard arrays.

Returns

A mesh object
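
The default axis_dims of (1, -1, 1, 1) contains a -1 entry. A plausible reading, assumed here rather than confirmed by the source, is that -1 is inferred from the number of available devices in the same way a -1 works in an array reshape. The pure-Python helper resolve_axis_dims below is hypothetical and only illustrates that arithmetic; it does not create a JAX mesh.

```python
import math


def resolve_axis_dims(axis_dims, num_devices):
    """Hypothetical helper: replace a single -1 entry so the product equals num_devices."""
    unknown = [i for i, dim in enumerate(axis_dims) if dim == -1]
    if len(unknown) > 1:
        raise ValueError("at most one axis may be -1")

    # Product of the explicitly specified axis sizes.
    known = math.prod(dim for dim in axis_dims if dim != -1)
    resolved = list(axis_dims)

    if unknown:
        if num_devices % known != 0:
            raise ValueError("num_devices is not divisible by the fixed axis sizes")
        resolved[unknown[0]] = num_devices // known
    elif known != num_devices:
        raise ValueError("axis sizes do not multiply to num_devices")
    return tuple(resolved)


# With 8 devices, the default layout places all of them on the "fsdp" axis.
print(resolve_axis_dims((1, -1, 1, 1), num_devices=8))
```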

classmethod from_pretrained(pretrained_model_name_or_path: Union[str, PathLike], cache_dir: Optional[Union[str, PathLike]] = None, force_download: bool = False, local_files_only: bool = False, token: Optional[Union[str, bool]] = None, revision: str = 'main', **kwargs) PretrainedConfig[source]#

Instantiate a [PretrainedConfig] (or a derived class) from a pretrained model configuration.

Parameters
  • pretrained_model_name_or_path (str or os.PathLike) –

    This can be either:

    • a string, the model id of a pretrained model configuration hosted inside a model repo on huggingface.co.

    • a path to a directory containing a configuration file saved using the [~PretrainedConfig.save_pretrained] method, e.g., ./my_model_directory/.

    • a path or url to a saved configuration JSON file, e.g., ./my_model_directory/configuration.json.

  • cache_dir (str or os.PathLike, optional) – Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.

  • force_download (bool, optional, defaults to False) – Whether or not to force to (re-)download the configuration files and override the cached versions if they exist.

  • resume_download – Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.

  • proxies (Dict[str, str], optional) – A dictionary of proxy servers to use by protocol or endpoint, e.g., {‘http’: ‘foo.bar:3128’, ‘http://hostname’: ‘foo.bar:4012’}. The proxies are used on each request.

  • token (str or bool, optional) – The token to use as HTTP bearer authorization for remote files. If True, or not specified, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).

  • revision (str, optional, defaults to “main”) –

    The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.

    Tip: To test a pull request you made on the Hub, you can pass revision="refs/pr/<pr_number>".

  • return_unused_kwargs (bool, optional, defaults to False) –

    If False, then this function returns just the final configuration object.

    If True, then this function returns a tp.Tuple(config, unused_kwargs) where unused_kwargs is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the part of kwargs which has not been used to update config and is otherwise ignored.

  • subfolder (str, optional, defaults to “”) – In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here.

  • kwargs (Dict[str, tp.Any], optional) – The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not configuration attributes is controlled by the return_unused_kwargs keyword parameter.

Returns

The configuration object instantiated from this pretrained model.

Return type

[PretrainedConfig]
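
The interaction between kwargs and return_unused_kwargs described in the parameter list can be sketched with plain dictionaries. The helper apply_overrides is hypothetical and only mimics the documented behavior: keys that are configuration attributes override the loaded values, and the rest are either dropped or returned separately.

```python
def apply_overrides(config_attrs, kwargs, return_unused_kwargs=False):
    """Hypothetical helper mimicking the documented kwargs handling."""
    config = dict(config_attrs)
    unused = {}
    for key, value in kwargs.items():
        if key in config:
            # Known configuration attribute: override the loaded value.
            config[key] = value
        else:
            # Unknown key: not applied to the configuration.
            unused[key] = value
    if return_unused_kwargs:
        return config, unused
    return config


loaded = {"output_attentions": False, "hidden_size": 768}
config, unused_kwargs = apply_overrides(
    loaded, {"output_attentions": True, "foo": False}, return_unused_kwargs=True
)
print(config["output_attentions"], unused_kwargs)
```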

Examples:

>>> # The base class PretrainedConfig cannot be instantiated directly, so the examples use a
>>> # derived class: BertConfig
>>> from transformers import BertConfig
>>> config = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased"
... )  # Download configuration from huggingface.co and cache.
>>> config = BertConfig.from_pretrained(
...     "./test/saved_model/"
... )  # E.g. config (or model) was saved using save_pretrained('./test/saved_model/')
>>> config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
>>> config = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased", output_attentions=True, foo=False
... )
>>> assert config.output_attentions == True
>>> config, unused_kwargs = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased",
...     output_attentions=True,
...     foo=False,
...     return_unused_kwargs=True,
... )
>>> assert config.output_attentions == True
>>> assert unused_kwargs == {"foo": False}


get_axis_dims() Sequence[int][source]#

The get_axis_dims function returns a sequence of integers representing the dimensions of each axis.

Parameters

self – The instance of the class.

Returns

The dimensions of the axes

get_axis_names() Sequence[str][source]#

The get_axis_names function returns a list of the names of the axes.

Parameters

self – The instance of the class.

Returns

A list of the names of all axes

get_backend() str[source]#

The get_backend function returns the backend that is currently being used. If no backend has been set, it will return the default JAX backend.

Parameters

self – The instance the method is bound to.

Returns

The backend platform

get_basic_causal_mask(*args, **kwargs)[source]#
get_basic_frequencies(head_size: Optional[int] = None, rotary_dim: Optional[int] = None, base: Optional[float] = None) Any[source]#

Get basic frequencies for rotary embeddings.

Parameters
  • head_size – Size of attention heads (defaults to self.head_dim)

  • rotary_dim – Dimension for rotary embeddings (defaults to head_size)

  • base – Base value for frequency computation (defaults to self.rope_theta)

Returns

ModuleCaches instance containing computed frequencies

get_basic_inv_frequencies(head_size: Optional[int] = None, rotary_dim: Optional[int] = None, base: Optional[float] = None, partial_rotary_factor: float = 1.0) Any[source]#

Get basic inverse frequencies for rotary embeddings.

Parameters
  • head_size – Size of attention heads (defaults to self.head_dim)

  • rotary_dim – Dimension for rotary embeddings (defaults to head_size)

  • base – Base value for frequency computation (defaults to self.rope_theta)

Returns

ModuleCaches instance containing computed inverse frequencies
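
To make the role of head_size, rotary_dim, and base concrete, here is a minimal pure-Python sketch of the conventional RoPE inverse-frequency formula, inv_freq[i] = base ** (-2i / rotary_dim). It assumes the standard formulation; the function names and the default base of 10000.0 are illustrative assumptions, not the EasyDeL implementation, which may also apply rope scaling.

```python
def rope_inv_frequencies(rotary_dim: int, base: float = 10000.0) -> list[float]:
    """Standard RoPE inverse frequencies: base ** (-2i / rotary_dim)."""
    if rotary_dim % 2 != 0:
        raise ValueError("rotary_dim must be even")
    return [base ** (-(2 * i) / rotary_dim) for i in range(rotary_dim // 2)]


def rope_angles(position: int, inv_freq: list[float]) -> list[float]:
    """Rotation angle of each dimension pair at a given token position."""
    return [position * f for f in inv_freq]


inv_freq = rope_inv_frequencies(rotary_dim=8, base=10000.0)
angles = rope_angles(position=3, inv_freq=inv_freq)
```

Each entry of inv_freq belongs to one pair of rotary dimensions, so the list has rotary_dim // 2 entries and the values decrease from 1.0 toward 1 / base.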

get_basic_rope(dtype: Union[Array, ndarray, bool, number], head_size: int, rotary_dim: Optional[int] = None, is_neox_style: bool = True, base: Optional[float] = None)[source]#

Get basic rotary position embeddings.

Parameters
  • dtype – Data type for the embeddings

  • head_size – Size of attention heads

  • rotary_dim – Dimension for rotary embeddings (defaults to head_size)

  • is_neox_style – Whether to use NeoX style embeddings

  • base – Base value for frequency computation (defaults to self.rope_theta)

Returns

Rotary position embedding function
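
The is_neox_style flag conventionally selects how dimensions are paired for rotation: NeoX style pairs element i with element i + half, while the interleaved (GPT-J) style pairs adjacent elements. The sketch below illustrates that convention in pure Python; the helper names are hypothetical and the code is not taken from EasyDeL.

```python
import math


def rotate_neox(x: list[float], angles: list[float]) -> list[float]:
    """NeoX style: pair element i with element i + half."""
    half = len(x) // 2
    out = [0.0] * len(x)
    for i, theta in enumerate(angles):
        c, s = math.cos(theta), math.sin(theta)
        out[i] = x[i] * c - x[i + half] * s
        out[i + half] = x[i + half] * c + x[i] * s
    return out


def rotate_interleaved(x: list[float], angles: list[float]) -> list[float]:
    """Interleaved (GPT-J) style: pair element 2i with element 2i + 1."""
    out = [0.0] * len(x)
    for i, theta in enumerate(angles):
        c, s = math.cos(theta), math.sin(theta)
        out[2 * i] = x[2 * i] * c - x[2 * i + 1] * s
        out[2 * i + 1] = x[2 * i + 1] * c + x[2 * i] * s
    return out


x = [1.0, 2.0, 3.0, 4.0]
angles = [math.pi / 2, 0.0]  # one angle per dimension pair
neox = rotate_neox(x, angles)
interleaved = rotate_interleaved(x, angles)
```

Both variants are plain 2D rotations and preserve the vector norm; they differ only in which elements form a pair, which is why the flag has to match the layout used by the checkpoint.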

get_fcm_mask(batch_size, seq_length, deterministic: bool)[source]#
get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
jax_mesh()[source]#
property mesh#

The mesh property is a helper property that creates a Mesh object from the axis_dims and axis_names attributes of an object, which are assumed to be lists of integers and strings, respectively. The platform attribute is also used if it exists.

Parameters

self – The object itself.

Returns

A JAX Mesh object

read_basics_from_config(config: EasyDeLBaseConfig)[source]#
to_dict() Dict[str, Any][source]#

Serializes this instance to a Python dictionary.

Returns

Dictionary of all the attributes that make up this configuration instance.

Return type

Dict[str, Any]

class easydel.__init__.EasyDeLBaseConfigDict[source]#

Bases: TypedDict

class easydel.__init__.EasyDeLBaseModule(*args: Any, **kwargs: Any)[source]#

Bases: Module, BaseModuleProtocol, EasyBridgeMixin, EasyGenerationMixin

Base class for EasyDeL modules, providing common functionalities for model initialization, parameter handling, and integration with the EasyDeL ecosystem.

apply_lora_to_layers(lora_rank: int, lora_pattern: Optional[str] = None, verbose: bool = False, rngs: Optional[Rngs] = None) SELF[source]#

Applies Low-Rank Adaptation (LoRA) layers to the specified linear layers within the module.

Replaces targeted flax.linen.Dense layers with easydel.layers.lora.LoraLinear layers, initializing the LoRA matrices (A and B).

Parameters
  • lora_rank (int) – The rank of the LoRA decomposition.

  • lora_pattern (tp.Optional[str], optional) – A regular expression to match the names of the Dense layers to apply LoRA to. If None, applies to common attention and MLP layers. Defaults to None.

  • verbose (bool, optional) – If True, prints information about which layers are being modified. Defaults to False.

  • rngs (tp.Optional[nn.Rngs], optional) – JAX random number generators for initializing LoRA matrices. If None, default RNGs might be used. Defaults to None.

Returns

The module instance with LoRA layers applied.

Return type

SELF
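
As an illustration of how a lora_pattern regular expression can narrow the set of targeted layers, the sketch below filters a list of layer paths with Python's re module. The layer names are hypothetical, and whether EasyDeL matches with search, match, or fullmatch is an assumption here.

```python
import re

# Hypothetical layer paths; real names depend on the model architecture.
layer_names = [
    "model/layers/0/self_attn/q_proj",
    "model/layers/0/self_attn/v_proj",
    "model/layers/0/mlp/gate_proj",
    "model/embed_tokens",
]


def select_layers(names: list[str], pattern: str) -> list[str]:
    """Return the names that the regular expression matches anywhere in the path."""
    compiled = re.compile(pattern)
    return [name for name in names if compiled.search(name)]


attention_only = select_layers(layer_names, r"self_attn/(q_proj|v_proj)$")
```

Testing a pattern against the parameter paths first (for example with verbose=True) is a cheap way to confirm that only the intended layers would be wrapped.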

property causal_mask: Array#

Retrieves or computes the basic causal attention mask from the configuration.

Uses self.config.get_basic_causal_mask() and caches the result.

Returns

The causal attention mask, potentially cached.

Return type

jnp.ndarray

compute_complex_rotary(position_ids: Array) Array[source]#
compute_loss(*, labels: Optional[Union[Array, ndarray, bool, number]] = None, loss_config: Optional[LossConfig] = None, loss_kwargs: Optional[Dict] = None, **batch) Tuple[Any, LossMetrics][source]#

Computes the loss for the model given a batch of inputs and labels.

This method performs a forward pass using the provided batch arguments, then calculates the loss using the determined loss_function. It handles potential label inference (e.g., using input_ids as labels for Causal LM) and default loss configurations.

Parameters
  • labels (tp.Optional[chex.Array], optional) – The target labels. If None and the task is Causal LM, input_ids from the batch might be used. Defaults to None.

  • loss_config (tp.Optional[LossConfig], optional) – Specific configuration for the loss calculation. If None, defaults might be inferred (e.g., for sequence classification). Defaults to None.

  • loss_kwargs (tp.Optional[tp.Dict], optional) – Additional keyword arguments to pass directly to the loss function. Defaults to None.

  • **batch – Keyword arguments representing the input batch (e.g., input_ids, attention_mask).

Returns

A tuple containing:
  • The model’s output (a pytree typically including logits, hidden states, etc.)

  • A LossMetrics object containing the calculated loss and potentially other metrics.

Return type

tp.Tuple[tp.Any, LossMetrics]

Raises
  • AssertionError – If labels are required for the loss function but are not provided or inferred.

  • AssertionError – If sequence classification loss is used without num_labels in the config.
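
The label inference mentioned above (reusing input_ids as labels for causal language modeling) conventionally shifts the sequence by one position so that each token predicts its successor. The following pure-Python sketch shows that convention for a single sequence; it is a simplified illustration under that assumption, not the ForCausalLMLoss implementation.

```python
import math


def causal_lm_loss(logits: list[list[float]], input_ids: list[int]) -> float:
    """Mean next-token cross-entropy, using input_ids shifted by one as labels."""
    shifted_logits = logits[:-1]    # predictions made at positions 0 .. n-2
    shifted_labels = input_ids[1:]  # targets are the tokens at positions 1 .. n-1
    total = 0.0
    for row, label in zip(shifted_logits, shifted_labels):
        peak = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_norm = peak + math.log(sum(math.exp(v - peak) for v in row))
        total += log_norm - row[label]
    return total / len(shifted_labels)


# Three positions over a vocabulary of two tokens, with uniform logits.
loss = causal_lm_loss([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], [0, 1, 0])
```

With uniform logits over two tokens, every prediction has probability 0.5, so the mean loss equals log(2).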

float(change_runtime_dtype: bool = True) SELF[source]#

Converts the module’s parameters to single-precision (float32).

Optionally also changes the runtime computation dtype (self.dtype) to float32.

Parameters

change_runtime_dtype (bool) – If True, also sets self.dtype to jnp.float32. Defaults to True.

Returns

The module instance with parameters (and potentially runtime dtype) set to float32.

Return type

SELF

property frequencies: Array#

Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_frequencies() and caches the result.

Returns

The frequency components, potentially cached.

Return type

jnp.ndarray

fully_gather() SELF[source]#

Applies JAX sharding constraints to gather all parameters onto the host or a single device.

This function marks all parameters to have no sharding (PartitionSpec()). It uses jax.jit with out_shardings to enforce these gathering constraints.

Returns

The model instance with gathering constraints applied.

Return type

SELF

fully_shard(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None) SELF[source]#

Applies JAX sharding constraints to all parameters based on the partition rules.

This function ensures that parameters are explicitly marked with their intended sharding, which can be useful for performance and correctness checks. It uses jax.jit with out_shardings to enforce the constraints.

Parameters

partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses config rules. Defaults to None.

Returns

The model instance with sharding constraints applied.

Return type

SELF

gather_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None, overlay_fns: Optional[Mapping[str, Callable]] = None) SELF[source]#

Gathers the model’s parameters from potentially distributed devices to the host or a single device.

Parameters
  • partition_rules (PartitionLike, optional) – Partitioning rules used to determine how parameters were originally sharded. If None, uses config rules. Defaults to None.

  • mesh (tp.Optional[Mesh], optional) – JAX device mesh from which to gather. If None, uses config mesh. Defaults to None.

  • overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]], optional) – Additional functions to apply, potentially overriding default gathering for specific parameters. Defaults to None.

Returns

The model instance with gathered parameters.

Return type

SELF

get_static_arguments() Tuple[source]#

Returns a tuple of static arguments required by the module’s __call__ method.

Static arguments are those that don’t change across calls and can be potentially cached or handled differently by JIT compilation. This base implementation returns an empty tuple. Subclasses should override this if they have static arguments.

Returns

A tuple containing static arguments.

Return type

tp.Tuple

property graphdef: Union[NodeDef[Node], NodeRef[Node]]#

Returns the graph definition (structure without parameters) of the module.

Uses flax.nnx.split to separate the graph definition from the state (parameters).

Returns

The graph definition of the module.

Return type

nn.GraphDef

property graphother: State[Key, VariableState[Any]]#

Returns any other state variables in the module (non-parameters).

Uses flax.nnx.split to separate non-parameter state variables.

Returns

The graph state containing non-parameter variables.

Return type

nn.GraphState

property graphstate: State[Key, VariableState[Any]]#

Returns the graph state (parameters) of the module.

Uses flax.nnx.split to separate the state (parameters) from the graph definition.

Returns

The graph state containing the module’s parameters.

Return type

nn.GraphState

property graphtree_params_shape: Dict#

Computes and returns the shapes of the module’s parameters as a nested dictionary.

It uses nnx.eval_shape to determine the shapes without actual computation, then extracts the shape information from the resulting graph state.

Returns

A nested dictionary mirroring the parameter structure, containing their shapes.

Return type

tp.Dict

property graphtree_shape: Dict#

Computes and returns the shapes of all state variables (including non-parameters) in the module.

Uses nnx.eval_shape on the entire module state (parameters and others) and extracts the shape information.

Returns

A nested dictionary mirroring the module’s state structure, containing the shapes.

Return type

tp.Dict

half(change_runtime_dtype: bool = True) SELF[source]#

Converts the module’s parameters to half-precision (float16).

Optionally also changes the runtime computation dtype (self.dtype) to float16.

Parameters

change_runtime_dtype (bool) – If True, also sets self.dtype to jnp.float16. Defaults to True.

Returns

The module instance with parameters (and potentially runtime dtype) set to float16.

Return type

SELF

property inv_frequencies: Array#

Retrieves or computes the inv-frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_inv_frequencies() and caches the result.

Returns

The inv-frequency components, potentially cached.

Return type

jnp.ndarray

classmethod lazy_init(*args, **kwargs) SELF[source]#

Performs a “lazy” initialization using nnx.eval_shape.

This initializes the module structure and determines parameter shapes without actually allocating memory for the parameters. Useful for inspecting the model structure or preparing for sharding.

Parameters
  • *args – Positional arguments passed to the class constructor.

  • **kwargs – Keyword arguments passed to the class constructor.

Returns

A module instance with initialized structure but potentially abstract parameters.

Return type

SELF

property loss_function#

Determines and returns the appropriate loss function based on the configuration or model type.

It prioritizes config.loss_type, then self.loss_type, and finally tries to infer the loss type from the class name. If no suitable loss function is found, it defaults to ForCausalLMLoss and issues a warning.

Returns

The selected loss function (e.g., ForCausalLMLoss, ForSequenceClassificationLoss).

Return type

tp.Callable
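
The resolution order described above can be sketched as a small lookup function. The registry contents, the function name, and the example class name are hypothetical; only the priority (config, then module attribute, then class-name inference, then a warned fallback) comes from the description.

```python
import warnings

# Hypothetical registry; the real loss functions live inside EasyDeL.
LOSS_REGISTRY = {
    "ForCausalLM": "ForCausalLMLoss",
    "ForSequenceClassification": "ForSequenceClassificationLoss",
}


def resolve_loss_name(config_loss_type, module_loss_type, class_name):
    """Pick a loss by priority: config, then module, then class-name inference."""
    for candidate in (config_loss_type, module_loss_type):
        if candidate in LOSS_REGISTRY:
            return LOSS_REGISTRY[candidate]
    for key, loss_name in LOSS_REGISTRY.items():
        if key in class_name:
            return loss_name
    warnings.warn("No matching loss type found; falling back to ForCausalLMLoss.")
    return "ForCausalLMLoss"


inferred = resolve_loss_name(None, None, "LlamaForSequenceClassification")
explicit = resolve_loss_name("ForCausalLM", "ForSequenceClassification", "AnyModel")
```

Setting loss_type explicitly in the config avoids relying on class-name inference, which only works when the class name carries a recognizable task suffix.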

merge_lora_params(pytree: Dict) SELF[source]#

Merges LoRA parameters from a pytree into the base model’s parameters.

Parameters

pytree (tp.Dict) – A dictionary (pytree) containing the LoRA parameters (A and B matrices) structured similarly to the base model’s parameters.

Returns

The module instance with LoRA parameters merged into the base weights.

Return type

SELF
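
Merging LoRA weights conventionally means folding the low-rank product into the dense weight. The sketch below uses the layout from the LoRA paper, W + scale * (B @ A) with B of shape (out, rank) and A of shape (rank, in); the layout, the scale argument, and the helper names are assumptions for illustration and may not match EasyDeL's internal convention.

```python
def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(weight, lora_b, lora_a, scale=1.0):
    """Fold a low-rank update into a dense weight: W + scale * (B @ A)."""
    delta = matmul(lora_b, lora_a)
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


weight = [[1.0, 0.0], [0.0, 1.0]]  # dense weight, shape (2, 2)
lora_b = [[1.0], [2.0]]            # shape (2, rank=1)
lora_a = [[3.0, 4.0]]              # shape (rank=1, 2)
merged = merge_lora(weight, lora_b, lora_a)
```

After merging, inference needs only the single dense matrix, so the adapter adds no extra matrix multiplications at run time.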

merge_params(tree)[source]#

Merges a given parameter state tree back into the module.

Reconstructs the module using its existing graph definition and ‘other’ state, but replaces the parameter state with the provided tree.

Parameters

tree – A pytree (likely a nn.GraphState) containing the parameters to merge.

Returns

The module instance with the new parameters merged in.

Return type

EasyDeLBaseModule

merge_params_dict(params_dict: Dict) SELF[source]#

Merges parameters from a dictionary back into the module’s state.

Updates the module’s current parameter state with values from the provided dictionary.

Parameters

params_dict (tp.Dict) – A nested dictionary containing the parameters to merge. The structure should match the module’s parameter structure.

Returns

The module instance with the parameters from the dictionary merged in.

Return type

SELF

Raises

KeyError – If a key from params_dict is not found in the module’s current state.
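
The merge semantics described above (replace matching leaves, raise KeyError for unknown keys) can be sketched with plain nested dictionaries. The function name and the example parameter names are hypothetical, and the sketch returns a copy rather than updating a module in place.

```python
def merge_nested(current: dict, updates: dict, path: str = "") -> dict:
    """Return a copy of `current` with leaves replaced from `updates`.

    Raises KeyError when `updates` names a key that `current` does not have.
    """
    merged = dict(current)
    for key, value in updates.items():
        full_key = f"{path}/{key}" if path else str(key)
        if key not in current:
            raise KeyError(f"unknown parameter: {full_key}")
        if isinstance(value, dict) and isinstance(current[key], dict):
            merged[key] = merge_nested(current[key], value, full_key)
        else:
            merged[key] = value
    return merged


state = {"dense": {"kernel": [1.0, 2.0], "bias": [0.0]}}
updated = merge_nested(state, {"dense": {"bias": [0.5]}})
```

Failing loudly on unknown keys catches misspelled or mismatched parameter paths before they silently leave part of the model untouched.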

property mesh: Mesh#

Retrieves the JAX device mesh from the module’s configuration.

Returns

The device mesh defined in self.config.mesh.

Return type

jax.sharding.Mesh

property model_task: Optional[str]#

Returns the specific task associated with this model instance (e.g., ‘causal-language-model’).

Returns

The model task identifier, or None if not set.

Return type

tp.Optional[str]

property model_type: Optional[str]#

Returns the specific type of this model instance (e.g., ‘llama’, ‘mistral’).

Returns

The model type identifier, or None if not set.

Return type

tp.Optional[str]

property module_dtype: dtype#

Determines the data type of the module’s parameters.

It inspects the flattened parameter state to find the dtype of the first parameter encountered.

Returns

The data type of the module’s parameters.

Return type

jnp.dtype

property parameters: Dict#

Retrieves the parameters of the module as a dictionary.

This property iterates through the module and its submodules, extracting variables marked as nn.Param and returning them in a flat dictionary where keys represent the parameter path.

Returns

A dictionary containing the module’s parameters.

Return type

tp.Dict

property params: Dict#

Returns the parameters and other state variables of the module as a dictionary.

Uses flax.nnx.split to get the combined state (parameters and others).

Returns

A dictionary containing all state variables of the module.

Return type

tp.Dict

property params_sharding: Dict#

Retrieves the sharding annotation for each parameter in the module.

Returns

A nested dictionary mirroring the parameter structure, containing the sharding information (e.g., NamedSharding, PartitionSpec) for each parameter, or None if unsharded.

Return type

tp.Dict

prepare_inputs_for_call(**kwargs)[source]#

Prepares keyword arguments before passing them to the module’s __call__ method.

This base implementation simply returns the kwargs as is. Subclasses can override this to modify or add arguments as needed (e.g., for generation).

Parameters

**kwargs – The keyword arguments intended for __call__.

Returns

The prepared keyword arguments.

Return type

dict

property pure_transform_fn#

Returns a pure transformation function for PyTorch state dicts to EasyDeL parameters.

Similar to transform_fn, but this version does not include sharding functions. It identifies embedding and LayerNorm layers and returns a partial function (torch_dict_to_easydel_params) configured only with layer names and dtype.

Returns

A partial function for converting a PyTorch state dict without applying sharding.

Return type

tp.Callable

quantize(method: EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.A8BIT, block_size: int = 128, quantization_pattern: Optional[str] = None, quantize_tensors: bool = True, verbose: Optional[bool] = None) SELF[source]#

Applies quantization to the module’s linear layers or tensors.

Parameters
  • method (EasyDeLQuantizationMethods, optional) – The quantization algorithm to use (e.g., A8BIT, NF4). Defaults to EasyDeLQuantizationMethods.A8BIT.

  • block_size (int, optional) – The block size for quantization methods that support it. Defaults to 128.

  • quantization_pattern (tp.Optional[str], optional) – A regular expression to match parameter names that should be quantized. If None, uses a default pattern. Defaults to None.

  • quantize_tensors (bool, optional) – If True, quantizes the tensor values directly. If False, replaces Linear layers with their quantized equivalents. The signature default is True, although the implementation currently behaves as described for False.

  • verbose (tp.Optional[bool], optional) – If True, logs information during the quantization process. Defaults to True only on process index 0.

Returns

The quantized model instance.

Return type

SELF
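
To show what block_size controls in a block-wise scheme, here is a pure-Python sketch of 8-bit affine quantization where every block stores its own scale and offset. It is a generic illustration under stated assumptions; the function names are hypothetical and the storage format EasyDeL actually uses for A8BIT or NF4 is not shown in this reference.

```python
def quantize_block(values: list[float]) -> tuple[list[int], float, float]:
    """Affine-quantize one block to unsigned 8-bit codes plus (scale, offset)."""
    low, high = min(values), max(values)
    scale = (high - low) / 255 or 1.0  # fall back to 1.0 for constant blocks
    codes = [round((v - low) / scale) for v in values]
    return codes, scale, low


def dequantize_block(codes: list[int], scale: float, offset: float) -> list[float]:
    """Recover approximate floats from 8-bit codes."""
    return [c * scale + offset for c in codes]


def quantize(values: list[float], block_size: int = 128):
    """Quantize a flat tensor block by block, each block with its own scale."""
    return [
        quantize_block(values[i : i + block_size])
        for i in range(0, len(values), block_size)
    ]


weights = [-1.0, -0.5, 0.0, 0.5, 1.0, 100.0, 101.0, 102.0]
blocks = quantize(weights, block_size=4)
restored = [
    value
    for codes, scale, offset in blocks
    for value in dequantize_block(codes, scale, offset)
]
```

Smaller blocks keep an outlier from stretching the scale of unrelated values, at the cost of storing more scale and offset pairs.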

shard_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None, overlay_fns: Optional[Mapping[str, Callable]] = None) SELF[source]#

Shards the model’s parameters according to the specified rules and mesh.

Parameters
  • partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses config rules. Defaults to None.

  • mesh (tp.Optional[Mesh], optional) – JAX device mesh. If None, uses config mesh. Defaults to None.

  • overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]], optional) – Additional functions to apply, potentially overriding default sharding for specific parameters. Defaults to None.

Returns

The sharded model instance.

Return type

SELF

split_lora_params() Dict[source]#

Splits merged LoRA parameters back out from the base model’s parameters.

This function assumes LoRA parameters were previously merged using merge_lora_params or a similar process that stored the original base weights and LoRA weights appropriately.

Returns

A pytree containing the extracted LoRA parameters (A and B matrices). The base model parameters are restored to their original (pre-merge) state.

Return type

tp.Dict

split_params()[source]#

Splits the module and returns the parameter state.

Uses nnx.split to extract the GraphState containing the parameters.

Returns

The parameter state of the module.

Return type

nn.GraphState

split_params_dict(extract_fn: Optional[Callable] = None, remove_none: bool = True) Dict[source]#

Splits the module parameters and returns them as a nested dictionary.

Extracts the parameter state, converts it to a plain dictionary (removing VariableState wrappers), and optionally removes entries with None values.

Parameters
  • extract_fn (tp.Optional[tp.Callable], optional) – A function to apply to each parameter during extraction. Defaults to None.

  • remove_none (bool, optional) – If True, removes key-value pairs where the value is None. Defaults to True.

Returns

A nested dictionary containing the module’s parameters.

Return type

tp.Dict
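
The effect of extract_fn and remove_none can be sketched on a plain nested dictionary. The helper name and the example tree are hypothetical; the real method operates on the module's parameter state rather than on a hand-written dict.

```python
def clean_tree(tree: dict, extract_fn=None, remove_none: bool = True) -> dict:
    """Copy a nested dict, transforming leaves and optionally dropping None."""
    result = {}
    for key, value in tree.items():
        if isinstance(value, dict):
            result[key] = clean_tree(value, extract_fn, remove_none)
        elif value is None:
            if not remove_none:
                result[key] = None
        else:
            result[key] = extract_fn(value) if extract_fn else value
    return result


tree = {"dense": {"kernel": [1.0, 2.0], "bias": None}}
cleaned = clean_tree(tree, extract_fn=tuple)
kept = clean_tree(tree, remove_none=False)
```

Dropping None entries yields a dictionary that contains only real arrays, which is usually what serialization code expects.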

property static_arguments: Tuple#

Retrieves or computes static arguments needed for the module’s __call__ method.

Uses self.get_static_arguments() and caches the result. Static arguments are typically those that don’t change during execution and can be pre-computed.

Returns

A tuple of static arguments.

Return type

tp.Tuple
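
The pattern of an overridable get_static_arguments() behind a cached property can be sketched with functools.cached_property. The classes below are hypothetical stand-ins, and using cached_property is an assumption about how such caching might be done, not a statement about EasyDeL internals.

```python
from functools import cached_property


class BaseModule:
    """Hypothetical stand-in for a module base class."""

    calls = 0  # counts how often static arguments are computed

    def get_static_arguments(self) -> tuple:
        # Base behaviour described above: no static arguments.
        return ()

    @cached_property
    def static_arguments(self) -> tuple:
        # Computed on first access, then stored on the instance.
        return self.get_static_arguments()


class VisionModule(BaseModule):
    """Hypothetical subclass that declares one static argument."""

    def get_static_arguments(self) -> tuple:
        self.calls += 1
        return ("image_size_224",)


module = VisionModule()
first = module.static_arguments
second = module.static_arguments
```

Because the value is cached per instance, the override runs once even when the property is read on every call.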

to_dtype(dtype: dtype) SELF[source]#

Converts the module’s parameters to the specified data type.

It iterates through the module’s parameters (excluding quantization-related ones) and casts them to the target dtype. It also updates the param_dtype attribute of the module and its submodules if they exist.

Parameters

dtype (jnp.dtype) – The target data type for the parameters.

Returns

The module instance with parameters converted to the specified dtype.

Return type

SELF

to_state() Any[source]#

Converts the current module instance into an EasyDeLState object.

This is useful for saving and managing the model’s state, including parameters and potentially optimizer state (though optimizer state is typically added later).

Returns

An EasyDeLState object representing the current model state.

Return type

EasyDeLState

to_torch(**kwargs)[source]#

Converts the EasyDeL module to its equivalent Hugging Face PyTorch model.

Requires the corresponding PyTorch model class to be available and registered. Uses utility functions to transfer parameters from JAX to PyTorch format.

Parameters

**kwargs – Additional keyword arguments passed to the parameter transformation function.

Returns

The equivalent Hugging Face PyTorch model with loaded weights.

Return type

torch.nn.Module

property transform_fn#

Returns a partial function for transforming PyTorch state dicts to EasyDeL parameters.

This function identifies embedding and LayerNorm layers within the module and creates a transformation function (torch_dict_to_easydel_params) pre-configured with these layer names, the target parameter dtype, and the module’s sharding functions.

Returns

A partial function ready to convert a PyTorch state dict.

Return type

tp.Callable

unwrap_lora_to_layers(verbose: bool = False) SELF[source]#

Reverts the application of LoRA layers, restoring the original linear layers.

Replaces easydel.layers.lora.LoraLinear layers with their original flax.linen.Dense counterparts, discarding the LoRA matrices.

Parameters

verbose (bool, optional) – If True, prints information about which layers are being reverted. Defaults to False.

Returns

The module instance with LoRA layers removed and original layers restored.

Return type

SELF

class easydel.__init__.EasyDeLGradientCheckPointers(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: str, Enum

Enumeration of gradient checkpointing strategies available in EasyDeL.

Gradient checkpointing is a technique to reduce memory usage during training by recomputing activations during the backward pass instead of storing them.

EVERYTHING_SAVEABLE#

Checkpoints residuals, attentions, and hidden states. This is the most memory-intensive checkpointing strategy.

NOTHING_SAVEABLE#

Checkpoints only the residuals. This strategy saves the most memory but requires more recomputation.

CHECKPOINT_DOTS#

Checkpoints matrix multiplications and intermediate activations.

CHECKPOINT_DOTS_WITH_NO_BATCH_DMIS#

Similar to CHECKPOINT_DOTS but avoids checkpointing operations involving batch dimensions.

NONE#

No gradient checkpointing is applied.

CHECKPOINT_DOTS = 'checkpoint_dots'#
CHECKPOINT_DOTS_WITH_NO_BATCH_DMIS = 'checkpoint_dots_with_no_batch_dims'#
EVERYTHING_SAVEABLE = 'everything_saveable'#
NONE = ''#
NOTHING_SAVEABLE = 'nothing_saveable'#
class easydel.__init__.EasyDeLOptimizers(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: str, Enum

Enumeration of available optimizers in the EasyDeL library.

ADAFACTOR#

Represents the Adafactor optimizer.

LION#

Represents the Lion optimizer.

ADAMW#

Represents the AdamW optimizer.

RMSPROP#

Represents the RMSprop optimizer.

ADAFACTOR = 'adafactor'#
ADAMW = 'adamw'#
LION = 'lion'#
RMSPROP = 'rmsprop'#
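
Because these enumerations derive from both str and Enum, members can be looked up from plain strings and compared with them directly. The class below is a hypothetical mirror of the members and values listed above, used so the behaviour can be shown without importing EasyDeL.

```python
from enum import Enum


class Optimizers(str, Enum):
    """Hypothetical mirror of the members and values listed above."""

    ADAFACTOR = "adafactor"
    ADAMW = "adamw"
    LION = "lion"
    RMSPROP = "rmsprop"


from_string = Optimizers("adamw")                # look up a member by its value
equals_plain_string = Optimizers.LION == "lion"  # the str mixin allows direct comparison

try:
    Optimizers("sgd")
    unknown_rejected = False
except ValueError:
    unknown_rejected = True  # values outside the enumeration are rejected
```

This is why configuration files can carry plain strings such as "adamw" while code still benefits from validated, named members.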
class easydel.__init__.EasyDeLPlatforms(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: str, Enum

Enumeration of platforms or kernel execution backends supported by EasyDeL.

This allows selecting optimized kernel implementations for different hardware or software environments.

JAX#

Use standard JAX kernel implementations.

TRITON#

Use Triton-based kernel implementations (often for GPUs).

PALLAS#

Use Pallas-based kernel implementations (often for TPUs).

JAX = 'jax'#
PALLAS = 'pallas'#
TRITON = 'triton'#
class easydel.__init__.EasyDeLQuantizationMethods(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: str, Enum

Enumeration of quantization methods supported by EasyDeL.

Quantization reduces the precision of model weights and/or activations to save memory and potentially speed up inference.

NONE#

No quantization is applied.

NF4#

Represents NormalFloat 4-bit quantization.

A8BIT#

Represents 8-bit affine quantization.

A8BIT = '8bit'#
NF4 = 'nf4'#
NONE = 'None'#
exception easydel.__init__.EasyDeLRuntimeError[source]#

Bases: Exception

class easydel.__init__.EasyDeLSchedulers(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: str, Enum

Enumeration of available learning rate schedulers in EasyDeL.

NONE#

Indicates no scheduler should be used.

LINEAR#

Represents a linear learning rate decay scheduler.

COSINE#

Represents a cosine annealing learning rate scheduler.

COSINE = 'cosine'#
LINEAR = 'linear'#
NONE = 'None'#
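
For readers unfamiliar with the two decay shapes, the sketch below writes out the textbook linear and cosine schedules. The function names, the end_lr parameter, and the absence of warmup are assumptions for illustration; the schedules EasyDeL builds may include warmup and other options.

```python
import math


def linear_decay(step: int, total_steps: int, peak_lr: float, end_lr: float = 0.0) -> float:
    """Interpolate linearly from peak_lr at step 0 to end_lr at total_steps."""
    progress = min(step, total_steps) / total_steps
    return peak_lr + (end_lr - peak_lr) * progress


def cosine_decay(step: int, total_steps: int, peak_lr: float, end_lr: float = 0.0) -> float:
    """Follow half a cosine wave from peak_lr down to end_lr."""
    progress = min(step, total_steps) / total_steps
    return end_lr + 0.5 * (peak_lr - end_lr) * (1.0 + math.cos(math.pi * progress))


linear_quarter = linear_decay(25, 100, peak_lr=1e-3)
cosine_quarter = cosine_decay(25, 100, peak_lr=1e-3)
```

Both schedules start and end at the same values; the cosine curve keeps the learning rate higher early on and drops faster through the middle of training.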
class easydel.__init__.EasyDeLState(step: int | jax.Array, graphdef: nn.GraphDef, graphstate: nn.GraphState, graphother: nn.GraphState, tx: optax.GradientTransformation, opt_state: tp.Optional[optax.OptState], apply_fn: tp.Optional[tp.Callable] = None)[source]#

Bases: PyTreeNode

Represents the state of an EasyDeL model during training or inference.

This class encapsulates the model’s parameters, optimizer state, training step, and potentially other metadata. It provides methods for applying gradients, managing sharding, saving, and loading the state.

step#

The current training step count.

Type

int | jax.Array

graphdef#

The definition of the model’s computation graph (structure).

Type

nn.GraphDef

graphstate#

The state of the model’s parameters.

Type

nn.GraphState

graphother#

The state of non-parameter variables within the model.

Type

nn.GraphState

tx#

The optimizer transformation (e.g., AdamW, SGD). Marked as a non-pytree node.

Type

optax.GradientTransformation

opt_state#

The state of the optimizer (e.g., moments). Marked as a pytree node.

Type

tp.Optional[optax.OptState]

apply_fn#

A function to apply the model (often model.__call__). Typically not directly part of the state but can be associated.

Type

tp.Optional[tp.Callable]

apply_fn: tp.Optional[tp.Callable] = None#
apply_gradients(*, grads)[source]#

Updates the model’s parameters and optimizer state based on calculated gradients.

Parameters

grads – A pytree matching the structure of self.graphstate containing the gradients.

Returns

A new state object with the updated parameters (graphstate), optimizer state (opt_state), and incremented step count.

Return type

EasyDeLState

Raises

AssertionError – If opt_state or tx is not initialized.
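
The functional update style described above, where apply_gradients returns a new state instead of mutating the old one, can be sketched with a frozen dataclass. TrainState is a hypothetical, simplified stand-in that uses plain SGD on a flat dict; the real class delegates the update to the optax transformation in tx.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class TrainState:
    """Hypothetical, simplified stand-in for an immutable training state."""

    step: int
    params: dict
    learning_rate: float

    def apply_gradients(self, *, grads: dict) -> "TrainState":
        # Plain SGD here; a real state would delegate to its optimizer (tx).
        new_params = {
            name: value - self.learning_rate * grads[name]
            for name, value in self.params.items()
        }
        return replace(self, step=self.step + 1, params=new_params)


state = TrainState(step=0, params={"w": 1.0, "b": 0.5}, learning_rate=0.1)
next_state = state.apply_gradients(grads={"w": 2.0, "b": -1.0})
```

Returning a fresh object keeps the update a pure function of its inputs, which is what allows a training step to be traced and compiled by JAX.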

classmethod create(*, step: tp.Optional[int] = None, graphdef: tp.Optional[nn.GraphDef] = None, graphstate: tp.Optional[nn.GraphState] = None, graphother: tp.Optional[nn.GraphState] = None, model: tp.Optional[nn.Module] = None, tx: tp.Optional[optax.GradientTransformation] = None, opt_state: tp.Optional[optax.OptState] = None, init_opt_state: bool = False) EasyDeLState[source]#

Creates a new EasyDeLState instance.

This class method provides a flexible way to initialize the state, either from an existing nn.Module or by providing the graph components (graphdef, graphstate, graphother) directly. It also handles optimizer state initialization.

Parameters
  • step (tp.Optional[int]) – The initial training step. Defaults to 0.

  • graphdef (tp.Optional[nn.GraphDef]) – The model’s graph definition.

  • graphstate (tp.Optional[nn.GraphState]) – The model’s parameter state.

  • graphother (tp.Optional[nn.GraphState]) – The model’s non-parameter state.

  • model (tp.Optional[nn.Module]) – An EasyDeL module instance. If provided, graphdef, graphstate, and graphother are derived from it. Cannot be provided simultaneously with graph components.

  • tx (tp.Optional[optax.GradientTransformation]) – The optimizer transformation.

  • opt_state (tp.Optional[optax.OptState]) – The initial optimizer state. Cannot be provided if init_opt_state is True.

  • init_opt_state (bool) – If True, initializes the optimizer state using tx.init(graphstate). Requires tx to be provided. Defaults to False.

Returns

A new instance of the state.

Return type

EasyDeLState

Raises
  • ValueError – If model and graph components are provided simultaneously.

  • ValueError – If graph components are provided partially.

  • ValueError – If init_opt_state is True and opt_state is also provided.

  • ValueError – If init_opt_state is True but tx is not provided.
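
The four error conditions listed above amount to a small set of argument checks. The sketch below expresses them as a standalone, hypothetical validation function; the error messages are illustrative and the real method may perform the checks in a different order.

```python
def validate_create_args(*, model=None, graphdef=None, graphstate=None,
                         graphother=None, tx=None, opt_state=None,
                         init_opt_state=False) -> None:
    """Raise ValueError for the invalid argument combinations listed above."""
    provided = [c is not None for c in (graphdef, graphstate, graphother)]
    if model is not None and any(provided):
        raise ValueError("pass either `model` or the graph components, not both")
    if any(provided) and not all(provided):
        raise ValueError("graph components must be provided together")
    if init_opt_state and opt_state is not None:
        raise ValueError("`opt_state` conflicts with `init_opt_state=True`")
    if init_opt_state and tx is None:
        raise ValueError("`init_opt_state=True` requires `tx`")


def is_valid(**kwargs) -> bool:
    """Helper for the examples: report whether a combination is accepted."""
    try:
        validate_create_args(**kwargs)
        return True
    except ValueError:
        return False


model_with_new_optimizer = is_valid(model=object(), tx=object(), init_opt_state=True)
partial_components = is_valid(graphdef=object())
```

In practice this means choosing one of two entry points: pass model alone, or pass graphdef, graphstate, and graphother together.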

gather_model(partition_rules: Any = None, mesh: Optional[Any] = None) EasyDeLState[source]#

Gathers the model parameters (graphstate and graphother) from distributed devices.

Parameters
  • partition_rules (PartitionLike, optional) – Partitioning rules used for the original sharding. If None, uses model config rules. Defaults to None.

  • mesh (tp.Optional[Mesh], optional) – The JAX device mesh to gather from. If None, uses model’s mesh. Defaults to None.

Returns

A new state object with gathered graphstate and graphother.

Return type

EasyDeLState

gather_optimizer_state(partition_rules=None)[source]#

Gathers the optimizer state from potentially distributed devices.

Parameters

partition_rules (PartitionLike, optional) – Partitioning rules used to determine how the state was sharded. If None, uses rules from the model’s config. Defaults to None.

Returns

A new state object with the gathered opt_state.

Return type

EasyDeLState

Raises

AssertionError – If opt_state is not initialized.

gather_state()[source]#

Gathers the entire state (model parameters and optimizer state) from distributed devices.

This is a convenience method that calls gather_model and gather_optimizer_state.

Returns

A new state object with both model and optimizer states gathered.

Return type

EasyDeLState

graphdef: nn.GraphDef#
graphother: nn.GraphState#
graphstate: nn.GraphState#
init_tx(tx: GradientTransformation, partition_rules: Any = None) EasyDeLState[source]#

Initializes the optimizer state (opt_state) for the current graphstate using the provided optimizer transformation (tx). It automatically handles sharding based on the model’s partition rules.

Parameters
  • tx (optax.GradientTransformation) – The optimizer transformation to initialize with.

  • partition_rules (PartitionLike, optional) – Partitioning rules for the optimizer state. If None, uses the rules from the associated model’s config. Defaults to None.

Returns

A new state object with the initialized and potentially sharded opt_state and the provided tx.

Return type

EasyDeLState

load_optimizer(load_directory: Union[str, PathLike])[source]#

Loads the optimizer state from saved files.

Reads the optimizer state structure from a pickle file (OPTIMIZER_STRUCT_NAME) and the tensor data from a SafeTensors file (OPTIMIZER_NAME) within the specified directory.

Parameters

load_directory (tp.Union[str, os.PathLike]) – The directory containing the saved optimizer state files.

Returns

A new state object with the loaded opt_state.

Return type

EasyDeLState

Raises
  • FileNotFoundError – If the required optimizer files are not found.

  • Exception – If any error occurs during loading or deserialization.

load_state(load_directory: Union[str, PathLike], verbose: bool = True)[source]#
merge(tree) Any[source]#

Merges a given state tree (usually parameters) with the graph definition and other state components to reconstruct the full model module.

Parameters

tree – The pytree (e.g., nn.GraphState) containing the parameters to merge.

Returns

The reconstructed model module.

Return type

EasyDeLBaseModule

merge_to_state(tree) EasyDeLState[source]#

Creates a new EasyDeLState by replacing the current graphstate with the provided tree.

Parameters

tree – The pytree (e.g., nn.GraphState) containing the new parameters.

Returns

A new state object with the updated graphstate.

Return type

EasyDeLState

property model: Any#

Reconstructs and returns the full EasyDeL model module from the state components.

Returns

The model module instance.

Return type

EasyDeLBaseModule

opt_state: tp.Optional[optax.OptState]#
replace(**updates)#

Returns a new object replacing the specified fields with new values.
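
The copy-on-replace behaviour described above can be sketched with a frozen standard-library dataclass. TinyState and its fields are hypothetical stand-ins, not the real EasyDeLState; the point is only that the original object is left untouched.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class TinyState:
    # Hypothetical stand-in for a state object; not the real EasyDeLState.
    step: int
    opt_state: object = None


old = TinyState(step=0)
new = replace(old, step=1)  # builds a new object; `old` keeps step == 0
```

Because the update returns a fresh object, callers need to rebind the result (for example `state = state.replace(step=1)`) rather than expect an in-place change.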

save_state(save_directory: Union[str, PathLike], float_dtype: Optional[dtype] = None, verbose: bool = True, mismatch_allowed: bool = True, save_optimizer: bool = True, enable: Optional[bool] = None)[source]#

Saves the entire EasyDeLState to a directory.

This includes saving the model parameters (using model.save_pretrained) and optionally the optimizer state.

Parameters
  • save_directory (tp.Union[str, os.PathLike]) – The directory to save the state to.

  • float_dtype (tp.Optional[jax.numpy.dtype]) – Optional dtype to cast floating-point parameters to before saving. Defaults to None.

  • verbose (bool) – If True, logs information during saving. Defaults to True.

  • mismatch_allowed (bool) – Passed to model.save_pretrained, allows saving even if the model structure differs slightly from expected. Defaults to True.

  • save_optimizer (bool) – If True, saves the optimizer state. Defaults to True.

  • enable (tp.Optional[bool]) – If set, controls whether saving happens (True) or is skipped (False). If None, saving typically occurs only on JAX process index 0. Defaults to None.
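
The three-way meaning of enable can be sketched as a small decision function. should_save is a hypothetical helper written for illustration; it assumes the None case reduces to a check against process index 0, as the description suggests.

```python
from typing import Optional


def should_save(enable: Optional[bool], process_index: int) -> bool:
    """Hypothetical helper mirroring the documented meaning of `enable`."""
    if enable is None:
        # Default: only the first process writes to disk.
        return process_index == 0
    # An explicit True/False overrides the per-process default.
    return enable
```

In a multi-host run this would mean that only one host writes unless enable=True is passed on the others.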

shard_model(partition_rules: Any = None, mesh: Optional[Any] = None) EasyDeLState[source]#

Shards the model parameters (graphstate and graphother) based on partition rules.

Parameters
  • partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses model config rules. Defaults to None.

  • mesh (tp.Optional[Mesh], optional) – The JAX device mesh to shard across. If None, uses model’s mesh. Defaults to None.

Returns

A new state object with sharded graphstate and graphother.

Return type

EasyDeLState

shard_optimizer_state(opt_state: Optional[Any] = None, partition_rules: Any = None) Any[source]#

Applies sharding to the optimizer state based on partition rules.

Parameters
  • opt_state (tp.Optional[tp.Any]) – The optimizer state pytree to shard. If None, uses self.opt_state. Defaults to None.

  • partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses rules from the model’s config. Defaults to None.

Returns

A new state object with the sharded opt_state.

Return type

EasyDeLState

Raises

ValueError – If optimizer state is not initialized (neither opt_state argument nor self.opt_state is available).

shard_state(partition_rules: Any = None) EasyDeLState[source]#

Shards the entire state (model parameters and optimizer state) based on partition rules.

This is a convenience method that calls shard_model and shard_optimizer_state.

Parameters

partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses rules from the model’s config. Defaults to None.

Returns

A new state object with both model and optimizer states sharded.

Return type

EasyDeLState

shard_with_shape(shape) EasyDeLState[source]#

Applies sharding constraints to the entire state based on a reference shape pytree.

This method takes a pytree shape which has the same structure as the EasyDeLState but contains sharding annotations (e.g., NamedSharding) instead of actual array data. It applies these shardings as constraints to the corresponding arrays in the current state.

Parameters

shape – A pytree with the same structure as self, containing sharding annotations.

Returns

A new state object with sharding constraints applied.

Return type

EasyDeLState
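
The idea of walking two pytrees with identical structure, one holding data and one holding sharding annotations, can be sketched in plain Python. Nested dicts stand in for the pytrees, tuples stand in for sharding annotations, and tree_map2 is a hypothetical helper; in JAX the per-leaf step would presumably be something like jax.lax.with_sharding_constraint.

```python
def tree_map2(fn, tree, other):
    """Apply fn leaf-wise over two nested dicts with identical structure."""
    if isinstance(tree, dict):
        if not isinstance(other, dict) or tree.keys() != other.keys():
            raise ValueError("structure mismatch between state and shape")
        return {key: tree_map2(fn, tree[key], other[key]) for key in tree}
    return fn(tree, other)


# Hypothetical state and matching annotation tree.
state = {"graphstate": {"wte": "array_a", "ln_f": "array_b"}, "step": 0}
shape = {"graphstate": {"wte": ("fsdp", "tp"), "ln_f": None}, "step": None}

# Pair each leaf with its annotation; a real implementation would constrain it.
constrained = tree_map2(lambda leaf, spec: (leaf, spec), state, shape)
```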

property shardings#

Retrieves the sharding annotations (e.g., NamedSharding) for all components of the EasyDeLState pytree.

Returns

A pytree with the same structure as self, containing sharding annotations or None for components without sharding.

property size: int#

Calculates the total size in bytes of the model parameters (graphstate) and the optimizer state (opt_state).

Returns

The total size in bytes.

Return type

int
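
A byte count of this kind is usually the sum, over every array leaf, of element count times bytes per element. The sketch below assumes that definition; the leaves list is hypothetical and uses (shape, itemsize) pairs in place of real arrays.

```python
import math

# Hypothetical (shape, bytes-per-element) pairs standing in for array leaves.
leaves = [
    ((1024, 768), 4),  # float32 weight
    ((768,), 4),       # float32 bias
    ((1024, 768), 2),  # bfloat16 optimizer moment
]


def total_bytes(leaves):
    """Sum element_count * itemsize over all leaves."""
    return sum(math.prod(shape) * itemsize for shape, itemsize in leaves)
```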

step: int | jax.Array#
tx: optax.GradientTransformation#
exception easydel.__init__.EasyDeLSyntaxRuntimeError[source]#

Bases: Exception

exception easydel.__init__.EasyDeLTimerError[source]#

Bases: Exception

class easydel.__init__.ExaoneConfig(vocab_size: int = 102400, hidden_size: int = 2048, intermediate_size: int = 14336, num_layers: int = 32, num_attention_heads: int = 32, num_key_value_heads: int = 8, activation_function='silu', max_position_embeddings=2048, initializer_range=0.02, layer_norm_epsilon=1e-05, use_cache=True, embed_dropout: float = 0.0, pad_token_id: Optional[int] = None, bos_token_id: int = 1, eos_token_id: int = 2, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling: Dict[str, Union[str, float]] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, attention_dropout: float = 0.0, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 102400) – Vocabulary size of the Exaone model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 2048) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 14336) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • head_dim (int, defaults to 128) – Dimensionality of the head for attention.

  • num_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 8) – Number of key and value heads for each attention layer in the Transformer encoder.

  • activation_function (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • layer_norm_epsilon (float, optional, defaults to 1e-5) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional) – The index of the padding token in the vocabulary.

  • bos_token_id (int, optional, defaults to 1) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • use_scan_mlp (bool, optional, defaults to False) – Whether to use the scan implementation for the MLP.

  • scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size to use when scanning the MLP.

  • bits (int, optional) – The number of bits to quantize the model to.

  • attention_bias (bool, optional, defaults to False) – Whether to use bias in the attention layer.
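
With the defaults above, num_key_value_heads is smaller than num_attention_heads, so several query heads share each key/value head. The sketch below assumes the common grouped-query layout in which consecutive query heads form a group; that layout is an assumption, not something this page states.

```python
num_attention_heads = 32  # ExaoneConfig default
num_key_value_heads = 8   # ExaoneConfig default

# Grouping only works when the head counts divide evenly.
assert num_attention_heads % num_key_value_heads == 0
queries_per_kv = num_attention_heads // num_key_value_heads

# Map each query head to the key/value head it reads from (assumed layout).
kv_head_for_query = [q // queries_per_kv for q in range(num_attention_heads)]
```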

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, attention_dropout: float = 0.0, rope_scaling: Dict[str, Union[str, float]] = None, attention_bias: bool = False, **kwargs)[source]#

The attach_custom_arguments function adds the following arguments to the model:

Parameters
  • gradient_checkpointing (str) – The gradient checkpointing strategy to use.

  • use_scan_mlp (bool) – Whether to use the scan implementation for the MLP.

  • scan_mlp_chunk_size (int) – The chunk size used to split the input to the MLP.

  • bits (tp.Optional[int]) – The number of bits to use for quantization.

  • attention_dropout (float) – The dropout rate for the attention layer.

  • attention_bias (bool) – Whether to use a bias in the attention layer.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]]) – The configuration for RoPE scaling.

Return type

A tuple of the following

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()[source]#

Returns a tuple of parameter names for which weight decay should be excluded.

Returns

An empty tuple, indicating no weight decay exclusions.

Return type

tuple
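
A tuple of excluded names is typically turned into a boolean mask that tells the optimizer where weight decay applies. The sketch below is hypothetical: decay_mask, the parameter names, and the choice to treat exclusions as regular expressions are all assumptions made for illustration.

```python
import re


def decay_mask(param_names, exclusions=()):
    """Return {name: True if weight decay applies}; hypothetical helper."""
    return {
        name: not any(re.search(pattern, name) for pattern in exclusions)
        for name in param_names
    }


names = ["layers/0/mlp/kernel", "layers/0/ln/scale"]
mask_all = decay_mask(names)                    # empty tuple: decay everywhere
mask_some = decay_mask(names, (r"ln/scale$",))  # hypothetical exclusion
```

With the empty tuple returned here, every parameter stays subject to weight decay.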

property granted_freq_max_position_embedding: int#

Returns the maximum position embedding size for frequency-based position embeddings.

Returns

The maximum position embedding size, falling back to max_position_embeddings if not explicitly set.

Return type

int

property granted_mask_max_position_embedding: int#

Returns the maximum position embedding size for mask-based position embeddings.

Returns

The maximum position embedding size, falling back to max_position_embeddings if not explicitly set.

Return type

int
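
The fallback described for both granted_* properties can be sketched as a property that prefers an explicit override and otherwise returns max_position_embeddings. TinyConfig and the attribute name freq_max_position_embeddings are assumptions made for illustration.

```python
class TinyConfig:
    # Hypothetical stand-in; `freq_max_position_embeddings` is an assumed name.
    def __init__(self, max_position_embeddings, freq_max_position_embeddings=None):
        self.max_position_embeddings = max_position_embeddings
        self.freq_max_position_embeddings = freq_max_position_embeddings

    @property
    def granted_freq_max_position_embedding(self):
        override = self.freq_max_position_embeddings
        # Fall back to the model's own limit when no override was given.
        return self.max_position_embeddings if override is None else override
```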

model_type: str = 'exaone'#
static rng_keys()[source]#

Returns the names of the random number generator keys used by the model.

Returns

A tuple containing “params”, “dropout”, and “fcm” as the RNG keys.

Return type

tuple

class easydel.__init__.ExaoneForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Exaone model with a language modeling head for causal language modeling tasks.

This model extends the base ExaoneModel by adding a linear language modeling head on top of the transformer model. It’s designed for generative tasks and can be used for text generation.

class easydel.__init__.ExaoneForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.ExaoneModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

property frequencies#

Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_frequencies() and caches the result.

Returns

The frequency components, potentially cached.

Return type

jnp.ndarray

class easydel.__init__.FalconConfig(vocab_size=65024, hidden_size=4544, num_hidden_layers=32, num_attention_heads=71, num_ln_in_parallel_attn=None, layer_norm_epsilon=1e-05, initializer_range=0.02, use_cache=True, hidden_dropout=0.0, attention_dropout=0.0, num_kv_heads=None, alibi=False, new_decoder_architecture=False, multi_query=True, parallel_attn=True, bias=False, max_position_embeddings=2048, rope_theta=10000.0, rope_scaling=None, bos_token_id=11, eos_token_id=11, ffn_hidden_size=None, ff_factor=None, activation='gelu', gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 65024) – Vocabulary size of the Falcon model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4544) – Dimensionality of the encoder layers and the pooler layer.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 71) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_ln_in_parallel_attn (int, optional) – The number of layer norms in the parallel attention layer.

  • layer_norm_epsilon (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • hidden_dropout (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • num_kv_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer encoder. Will default to num_attention_heads if not set.

  • alibi (bool, optional) – Whether to use alibi attention.

  • new_decoder_architecture (bool, optional) – Whether to use the new decoder architecture.

  • multi_query (bool, optional, defaults to True) – Whether to use multi-query attention.

  • parallel_attn (bool, optional, defaults to True) – Whether to use parallel attention.

  • bias (bool, optional, defaults to False) – Whether to use bias in the linear layers.

  • max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The rope scaling configuration.

  • bos_token_id (int, optional, defaults to 11) – The index of the beginning of sequence token in the vocabulary.

  • eos_token_id (int, optional, defaults to 11) – The index of the end of sequence token in the vocabulary.

  • ffn_hidden_size (int, optional) – Dimensionality of the hidden layer in the FFN

  • ff_factor (int, optional) – The scaling factor of the FFN

  • activation (str, optional, defaults to “gelu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • bits (int, optional) – The number of bits to quantize the model to.
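
How multi_query, new_decoder_architecture, and num_kv_heads interact is easiest to see as a small function. The sketch below follows the convention used by the Hugging Face Transformers Falcon implementation; whether EasyDeL resolves the flags in exactly this way is an assumption, and effective_kv_heads is a hypothetical helper.

```python
def effective_kv_heads(num_attention_heads, num_kv_heads=None,
                       multi_query=True, new_decoder_architecture=False):
    """Number of key/value heads actually used (assumed convention)."""
    if num_kv_heads is None:
        # Documented default: fall back to num_attention_heads.
        num_kv_heads = num_attention_heads
    if new_decoder_architecture or not multi_query:
        return num_kv_heads
    # Classic multi-query attention: one shared key/value head.
    return 1


# With the defaults above, each head has 4544 // 71 features.
head_dim = 4544 // 71
```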

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

Attach custom arguments to the configuration.

Parameters
  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing strategy. Defaults to EasyDeLGradientCheckPointers.NONE.

  • bits (int, optional) – Quantization bits. Defaults to None.

  • **kwargs – Additional keyword arguments.

Returns

The updated configuration instance.

Return type

FalconConfig

attribute_map: dict[str, str] = {'num_attention_heads': 'num_attention_heads', 'num_hidden_layers': 'num_hidden_layers'}#
static get_mesh_names()[source]#

Returns the mesh names used for model parallelism.

Returns

A tuple containing “dp”, “fsdp”, and “tp” as the mesh names.

Return type

tuple

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

property granted_freq_max_position_embedding: int#

Returns the maximum position embedding size for frequency-based position embeddings.

Returns

The maximum position embedding size, falling back to max_position_embeddings if not explicitly set.

Return type

int

property granted_mask_max_position_embedding: int#

Returns the maximum position embedding size for mask-based position embeddings.

Returns

The maximum position embedding size, falling back to max_position_embeddings if not explicitly set.

Return type

int

model_type: str = 'falcon'#
property rotary#
class easydel.__init__.FalconForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Falcon model with a language modeling head for causal language modeling tasks.

This model extends the base FalconModel by incorporating a linear language modeling head on top of the base model, designed for generative tasks and text generation. The model can use either alibi positional embeddings or rotary position embeddings (RoPE) based on configuration.

class easydel.__init__.FalconModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.GPT2Config(vocab_size=50257, n_positions=1024, n_embd=768, n_layer=12, n_head=12, n_inner=None, activation_function='gelu_new', resid_pdrop=0.1, embd_pdrop=0.1, attn_pdrop=0.1, layer_norm_epsilon=1e-05, initializer_range=0.02, summary_type='cls_index', summary_use_proj=True, summary_activation=None, summary_proj_to_labels=True, summary_first_dropout=0.1, scale_attn_weights=True, use_cache=True, bos_token_id=50256, eos_token_id=50256, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, tie_word_embeddings: bool = False, bits: Optional[int] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 50257) – Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • n_positions (int, optional, defaults to 1024) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • n_embd (int, optional, defaults to 768) – Dimensionality of the encoder layers and the pooler layer.

  • n_layer (int, optional, defaults to 12) – Number of hidden layers in the Transformer encoder.

  • n_head (int, optional, defaults to 12) – Number of attention heads for each attention layer in the Transformer encoder.

  • n_inner (int, optional) – Dimensionality of the inner feed-forward layers.

  • activation_function (str, optional, defaults to “gelu_new”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • resid_pdrop (float, optional, defaults to 0.1) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

  • embd_pdrop (float, optional, defaults to 0.1) – The dropout ratio for the embeddings.

  • attn_pdrop (float, optional, defaults to 0.1) – The dropout ratio for the attention probabilities.

  • layer_norm_epsilon (float, optional, defaults to 1e-5) – The epsilon to use in the layer normalization layers.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • summary_type (str, optional, defaults to “cls_index”) – The summary type to use. Possible values are “cls_index” (equivalent to the output of the last token of the first sentence in a sequence) and “last” (equivalent to the output of the last token in the sequence).

  • summary_use_proj (bool, optional, defaults to True) – Whether to use a projection after the vector extraction.

  • summary_activation (str, optional) – The activation to use for the summary.

  • summary_proj_to_labels (bool, optional, defaults to True) – Whether to project the summary to the labels.

  • summary_first_dropout (float, optional, defaults to 0.1) – The dropout ratio to be used after the projection and activation.

  • scale_attn_weights (bool, optional, defaults to True) – Scale attention weights by dividing by sqrt(hidden_size).

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • bos_token_id (int, optional, defaults to 50256) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 50256) – The id of the end-of-sequence token.

  • scale_attn_by_inverse_layer_idx (bool, optional, defaults to False) – Whether to additionally scale attention weights by 1 / (layer_idx + 1).

  • reorder_and_upcast_attn (bool, optional, defaults to False) – Whether to reorder and upcast attention.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • bits (int, optional) – The number of bits to quantize the model to.
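
The two attention-scaling flags, scale_attn_weights and scale_attn_by_inverse_layer_idx, combine multiplicatively. The sketch below is a hypothetical helper: dim is whichever dimension the implementation takes the square root of, and the layer factor is read as 1 / (layer_idx + 1) with a zero-based layer_idx, which is an assumption.

```python
import math


def attn_scale(dim, layer_idx, scale_attn_weights=True,
               scale_attn_by_inverse_layer_idx=False):
    """Multiplier applied to raw attention scores (hypothetical helper)."""
    scale = 1.0
    if scale_attn_weights:
        scale /= math.sqrt(dim)
    if scale_attn_by_inverse_layer_idx:
        # Deeper layers get smaller scores; layer_idx is assumed zero-based.
        scale /= float(layer_idx + 1)
    return scale
```

For example, with dim=64 the base factor is 1/8, and enabling the layer-index scaling at layer_idx=3 divides it by a further 4.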

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

Attaches custom arguments to the configuration object.

This method allows adding or overriding configuration attributes dynamically. It iterates through the provided arguments and sets them as attributes of the configuration object if they don’t already exist.

Parameters
  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing strategy. Defaults to EasyDeLGradientCheckPointers.NONE.

  • bits (tp.Optional[int], optional) – Quantization bits. Defaults to None.

  • **kwargs – Additional keyword arguments to attach to the configuration.

attribute_map: dict[str, str] = {'hidden_size': 'n_embd', 'max_position_embeddings': 'n_positions', 'num_attention_heads': 'n_head', 'num_hidden_layers': 'n_layer'}#
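
An attribute_map like the one above lets generic names such as hidden_size resolve to the model-specific fields. The sketch below shows one way such aliasing can work; AliasedConfig is a hypothetical stand-in, and the real base class may implement the lookup differently.

```python
class AliasedConfig:
    # Hypothetical stand-in showing how an attribute_map can alias names.
    attribute_map = {
        "hidden_size": "n_embd",
        "max_position_embeddings": "n_positions",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    def __init__(self, n_embd=768, n_positions=1024, n_head=12, n_layer=12):
        self.n_embd = n_embd
        self.n_positions = n_positions
        self.n_head = n_head
        self.n_layer = n_layer

    def __getattr__(self, name):
        # Called only when normal lookup fails, so real attributes win.
        target = type(self).attribute_map.get(name)
        if target is None:
            raise AttributeError(name)
        return getattr(self, target)

    def __setattr__(self, name, value):
        # Writes to an alias are redirected to the underlying field.
        super().__setattr__(type(self).attribute_map.get(name, name), value)


cfg = AliasedConfig()
layers = cfg.num_hidden_layers  # reads n_layer through the alias
cfg.hidden_size = 1024          # writes through to n_embd
```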
get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model. This method defines how the model’s parameters are partitioned across devices for distributed training and inference.

Parameters
  • *args – Additional positional arguments (unused).

  • **kwargs – Additional keyword arguments (unused).

Returns

A tuple of partition rules, where each rule is a tuple containing a regex pattern for parameter names and the corresponding PartitionSpec.

Return type

tp.Tuple[tp.Tuple[str, jax.sharding.PartitionSpec]]
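
Rules of the form (regex pattern, PartitionSpec) are typically resolved by testing each parameter name against the patterns in order. The sketch below assumes first-match-wins semantics with re.search; the rules, parameter names, and axis names are hypothetical, and plain tuples stand in for PartitionSpec.

```python
import re

# Stand-in rules: plain tuples play the role of PartitionSpec here.
rules = (
    (r"wte/embedding", ("tp", "fsdp")),
    (r"attn/c_attn/kernel", ("fsdp", "tp")),
    (r".*", (None,)),  # catch-all: replicate everything else
)


def spec_for(name, rules):
    """Return the spec of the first rule whose pattern matches the name."""
    for pattern, spec in rules:
        if re.search(pattern, name):
            return spec
    raise ValueError(f"no partition rule matches {name!r}")
```

Under this assumption rule order matters: placing the catch-all pattern first would shadow every more specific rule after it.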

keys_to_ignore_at_inference = ['past_key_values']#
model_type: str = 'gpt2'#
class easydel.__init__.GPT2LMHeadModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

GPT-2 model with a language modeling head.

This model extends the base GPT2Model by adding a linear layer on top to predict the next token in a sequence, making it suitable for causal language modeling tasks.

config#

Configuration object for the model.

Type

GPT2Config

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.__init__.GPT2Model(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

GPT-2 model implementation.

This class implements the main GPT-2 transformer model architecture, consisting of embedding layers (token and position), multiple GPT2Block layers, and a final layer normalization.

config#

Configuration object for the model.

Type

GPT2Config

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.__init__.GPTJConfig(vocab_size: int = 50400, n_positions: int = 2048, n_embd: int = 4096, n_layer: int = 28, n_head: int = 16, rotary_dim: int = 64, n_inner: int = None, activation_function: str = 'gelu_new', resid_pdrop: float = 0.0, embd_pdrop: float = 0.0, attn_pdrop: float = 0.0, layer_norm_epsilon: float = 1e-05, initializer_range: int = 0.02, use_cache: int = True, bos_token_id: int = 50256, eos_token_id: int = 50256, tie_word_embeddings: bool = False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 50400) – Vocabulary size of the GPT-J model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • n_positions (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • n_embd (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.

  • n_layer (int, optional, defaults to 28) – Number of hidden layers in the Transformer encoder.

  • n_head (int, optional, defaults to 16) – Number of attention heads for each attention layer in the Transformer encoder.

  • rotary_dim (int, optional, defaults to 64) – The dimension of the rotary position embedding.

  • n_inner (int, optional) – Dimensionality of the inner feed-forward layers.

  • activation_function (str, optional, defaults to “gelu_new”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • resid_pdrop (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

  • embd_pdrop (float, optional, defaults to 0.0) – The dropout ratio for the embeddings.

  • attn_pdrop (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • layer_norm_epsilon (float, optional, defaults to 1e-5) – The epsilon to use in the layer normalization layers.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • bos_token_id (int, optional, defaults to 50256) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 50256) – The id of the end-of-sequence token.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • bits (int, optional) – The number of bits to quantize the model to.
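
With the defaults above, rotary_dim is smaller than the per-head dimension, so only part of each head receives rotary position embeddings. The sketch below assumes the leading rotary_dim features of each head are the rotated ones, as in the Hugging Face Transformers GPT-J implementation; split_head is a hypothetical helper.

```python
n_embd, n_head, rotary_dim = 4096, 16, 64  # GPTJConfig defaults

head_dim = n_embd // n_head


def split_head(vec, rotary_dim):
    """Split one head's features into (rotary part, pass-through part)."""
    return vec[:rotary_dim], vec[rotary_dim:]


# A dummy feature vector for one head, just to show the split sizes.
rot, rest = split_head(list(range(head_dim)), rotary_dim)
```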

attach_custom_arguments(vocab_size: int = 50400, n_positions: int = 2048, n_embd: int = 4096, n_layer: int = 28, n_head: int = 16, rotary_dim: int = 64, n_inner: int = None, activation_function: str = 'gelu_new', resid_pdrop: float = 0.0, embd_pdrop: float = 0.0, attn_pdrop: float = 0.0, layer_norm_epsilon: float = 1e-05, initializer_range: int = 0.02, use_cache: int = True, bos_token_id: int = 50256, eos_token_id: int = 50256, tie_word_embeddings: bool = False, bits: Optional[int] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#

Attaches custom arguments to the configuration object.

This method allows adding or overriding configuration attributes dynamically. It iterates through a dictionary of basic configuration parameters and sets them as attributes of the configuration object if they don’t already exist.

Parameters
  • vocab_size (int, optional) – Vocabulary size. Defaults to 50400.

  • n_positions (int, optional) – Maximum sequence length. Defaults to 2048.

  • n_embd (int, optional) – Hidden size. Defaults to 4096.

  • n_layer (int, optional) – Number of hidden layers. Defaults to 28.

  • n_head (int, optional) – Number of attention heads. Defaults to 16.

  • rotary_dim (int, optional) – Dimension for rotary position embeddings. Defaults to 64.

  • n_inner (int, optional) – Inner dimension of FFN. Defaults to None.

  • activation_function (str, optional) – Activation function. Defaults to “gelu_new”.

  • resid_pdrop (float, optional) – Residual dropout probability. Defaults to 0.0.

  • embd_pdrop (float, optional) – Embedding dropout probability. Defaults to 0.0.

  • attn_pdrop (float, optional) – Attention dropout probability. Defaults to 0.0.

  • layer_norm_epsilon (float, optional) – Epsilon for layer normalization. Defaults to 1e-5.

  • initializer_range (float, optional) – Initializer range. Defaults to 0.02.

  • use_cache (bool, optional) – Whether to use KV cache. Defaults to True.

  • bos_token_id (int, optional) – Beginning-of-sequence token ID. Defaults to 50256.

  • eos_token_id (int, optional) – End-of-sequence token ID. Defaults to 50256.

  • tie_word_embeddings (bool, optional) – Whether to tie input/output embeddings. Defaults to False.

  • bits (tp.Optional[int], optional) – Quantization bits. Defaults to None.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing strategy. Defaults to EasyDeLGradientCheckPointers.NONE.

  • **kwargs – Additional keyword arguments.

Returns

The configuration object itself (self).

Return type

GPTJConfig

attribute_map: dict[str, str] = {'hidden_size': 'n_embd', 'max_position_embeddings': 'n_positions', 'num_attention_heads': 'n_head', 'num_hidden_layers': 'n_layer'}#
static get_mesh_names()[source]#

Returns the mesh names used for model parallelism. For GPT-J, it returns mesh names for data parallelism (‘dp’), fully sharded data parallelism (‘fsdp’), sequence parallelism (‘sp’), and tensor parallelism (‘tp’).

Returns

A tuple containing the mesh names (“dp”, “fsdp”, “sp”, “tp”).

Return type

tuple

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model. This method defines how the model’s parameters are partitioned across devices for distributed training and inference.

Parameters
  • *args – Additional positional arguments (unused).

  • **kwargs – Additional keyword arguments (unused).

Returns

A tuple of partition rules, where each rule is a tuple containing a regex pattern for parameter names and the corresponding PartitionSpec.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]
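
To illustrate how a tuple of `(regex, spec)` rules might be consumed, the sketch below looks up the spec for a parameter path. Everything here is an assumption made for illustration: the rule patterns and parameter paths are invented, plain tuples stand in for `PartitionSpec`, and the first-match-wins policy is a common convention rather than something this page states.

```python
import re

# Hypothetical rules: (regex over the parameter path, sharding spec).
# Plain tuples stand in for PartitionSpec objects.
RULES = (
    (r"wte/embedding", ("tp", ("fsdp", "sp"))),
    (r"attn/(q_proj|k_proj|v_proj)/kernel", (("fsdp", "sp"), "tp")),
    (r".*", (None,)),  # catch-all: leave the parameter replicated
)


def spec_for(param_path, rules=RULES):
    """Return the spec of the first rule whose pattern matches the path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no rule matches {param_path!r}")


print(spec_for("transformer/h/0/attn/q_proj/kernel"))
print(spec_for("transformer/ln_f/scale"))
```

Placing the catch-all pattern last keeps every parameter covered while letting the more specific rules take priority.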

model_type: str = 'gptj'#
class easydel.__init__.GPTJForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

GPT-J model with a language modeling head.

This model extends the base GPTJModel by adding a linear layer on top to predict the next token in a sequence, making it suitable for causal language modeling tasks.

config#

Configuration object for the model.

Type

GPTJConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.__init__.GPTJModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

GPT-J model implementation.

This class implements the main GPT-J transformer model architecture, consisting of an embedding layer, multiple GPTJBlock layers, and a final layer normalization.

config#

Configuration object for the model.

Type

GPTJConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

property frequencies#

Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_frequencies() and caches the result.

Returns

The frequency components, potentially cached.

Return type

jnp.ndarray
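
For orientation, the standard RoPE inverse frequencies can be computed as below. This is a sketch in plain Python, not the body of `get_basic_frequencies()`: the function name is hypothetical, and the base of 10000 is an assumption (the GPT-J parameter list gives `rotary_dim=64` but no base).

```python
def rope_inverse_frequencies(rotary_dim, base=10000.0):
    """Standard RoPE schedule: one frequency per pair of rotated dimensions."""
    # Pair i rotates with frequency base ** (-2i / rotary_dim).
    return [base ** (-(2 * i) / rotary_dim) for i in range(rotary_dim // 2)]


def rope_angles(position, inv_freq):
    """Rotation angle applied to each dimension pair at a given position."""
    return [position * f for f in inv_freq]


inv_freq = rope_inverse_frequencies(rotary_dim=64)
angles = rope_angles(position=3, inv_freq=inv_freq)
print(len(inv_freq), inv_freq[0], angles[0])
```

Because the schedule depends only on the configuration, computing it once and caching the result, as the property describes, avoids repeating the work on every forward pass.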

class easydel.__init__.GPTNeoXConfig(vocab_size=50432, hidden_size=6144, num_hidden_layers=44, num_attention_heads=64, intermediate_size=24576, hidden_act='gelu', rotary_pct=0.25, rotary_emb_base=10000, attention_dropout=0.0, hidden_dropout=0.0, classifier_dropout=0.1, max_position_embeddings=2048, initializer_range=0.02, layer_norm_eps=1e-05, use_cache=True, bos_token_id=0, eos_token_id=2, tie_word_embeddings=False, use_parallel_residual=True, rope_scaling=None, attention_bias=True, gradient_checkpointing=EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 50432) – Vocabulary size of the GPT NeoX model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 6144) – Dimensionality of the encoder layers and the pooler layer.

  • num_hidden_layers (int, optional, defaults to 44) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 64) – Number of attention heads for each attention layer in the Transformer encoder.

  • intermediate_size (int, optional, defaults to 24576) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • hidden_act (str or function, optional, defaults to “gelu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • rotary_pct (float, optional, defaults to 0.25) – The percentage of hidden dimensions to allocate to rotary embeddings.

  • rotary_emb_base (int, optional, defaults to 10000) – The base for the rotary position embedding.

  • classifier_dropout (float, optional, defaults to 0.1) – The dropout ratio for the classifier layer.

  • max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • layer_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • bos_token_id (int, optional, defaults to 0) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • use_parallel_residual (bool, optional, defaults to True) – Whether to use a parallel residual connection in the attention layer.
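
As a worked example of `rotary_pct`, the sketch below derives how many dimensions of each attention head receive rotary embeddings under the default configuration. It assumes the percentage is applied per head, as in the reference GPT-NeoX implementation; the helper name is hypothetical.

```python
def rotary_dims(hidden_size, num_attention_heads, rotary_pct):
    """Return (head_dim, number of rotary dimensions per head)."""
    head_dim = hidden_size // num_attention_heads
    # Only the first rotary_pct fraction of each head is rotated.
    return head_dim, int(head_dim * rotary_pct)


head_dim, rotary_ndims = rotary_dims(
    hidden_size=6144, num_attention_heads=64, rotary_pct=0.25
)
print(head_dim, rotary_ndims)
```

With the defaults, each head has 6144 / 64 = 96 dimensions, of which 96 × 0.25 = 24 are rotated and the remaining 72 pass through unchanged.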

attach_custom_arguments(**kwargs)[source]#

Attaches custom arguments to the configuration object.

This method allows adding or overriding configuration attributes dynamically. It primarily sets the from_pt attribute to False and ignores other keyword arguments.

Parameters

**kwargs – Additional keyword arguments (ignored).

static get_mesh_names()[source]#

Returns the mesh names used for model parallelism. For GPT-NeoX, it returns mesh names for data parallelism (‘dp’), fully sharded data parallelism (‘fsdp’), tensor parallelism (‘tp’), and sequence parallelism (‘sp’).

Returns

A tuple containing the mesh names (“dp”, “fsdp”, “tp”, “sp”).

Return type

tuple

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model. This method defines how the model’s parameters are partitioned across devices for distributed training and inference.

Parameters
  • *args – Additional positional arguments (unused).

  • **kwargs – Additional keyword arguments (unused).

Returns

A tuple of partition rules, where each rule is a tuple containing a regex pattern for parameter names and the corresponding PartitionSpec.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'gpt_neox'#
class easydel.__init__.GPTNeoXForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

GPT-NeoX model with a language modeling head.

This model extends the base GPTNeoXModel by adding a linear layer on top to predict the next token in a sequence, making it suitable for causal language modeling tasks.

config#

Configuration object for the model.

Type

GPTNeoXConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.__init__.GPTNeoXModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

GPT-NeoX model implementation.

This class implements the main GPT-NeoX transformer model architecture, consisting of an embedding layer, multiple GPTNeoXBlock layers, and a final layer normalization.

config#

Configuration object for the model.

Type

GPTNeoXConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

property frequencies#

Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_frequencies() and caches the result.

Returns

The frequency components, potentially cached.

Return type

jnp.ndarray

class easydel.__init__.GRPOConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: tp.Optional[str] = None, clip_grad: tp.Optional[float] = None, custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = None, dataloader_num_workers: tp.Optional[int] = 0, dataloader_pin_memory: tp.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: tp.Optional[int] = None, evaluation_steps: tp.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: tp.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: tp.Optional[dict] = None, learning_rate: float = 1e-06, learning_rate_end: tp.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: tp.Optional[LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: tp.Optional[int] = None, max_sequence_length: tp.Optional[int] = 4096, max_training_steps: tp.Optional[int] = None, model_name: str = 'GRPOTrainer', model_parameters: tp.Optional[dict] = None, metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: ~typing.Optional[bool] = False, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: tp.Optional[int] = None, save_total_limit: tp.Optional[int] = None, scheduler: AVAILABLE_SCHEDULERS = EasyDeLSchedulers.NONE, 
sparsify_module: bool = False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: tp.Optional[dict] = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: tp.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: tp.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: tp.Optional[jnp.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: tp.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0, max_prompt_length: int = 512, max_completion_length: int = 256, dataset_num_proc: ~typing.Optional[int] = None, beta: float = 0.04, sync_ref_model: bool = False, ref_model_mixup_alpha: float = 0.9, ref_model_sync_steps: int = 64, tools: ~typing.Optional[~typing.List[~typing.Union[dict, ~typing.Callable]]] = None, skip_apply_chat_template: bool = False)[source]#

Bases: TrainingArguments

Configuration class for the GRPOTrainer.

beta: float = 0.04#
dataset_num_proc: Optional[int] = None#
classmethod from_dict(data)#

Create an instance from a dictionary (deserialization).

classmethod from_json(json_str)#

Create an instance from a JSON string.

learning_rate: float = 1e-06#
max_completion_length: int = 256#
max_prompt_length: int = 512#
model_name: str = 'GRPOTrainer'#
ref_model_mixup_alpha: float = 0.9#
ref_model_sync_steps: int = 64#
remove_unused_columns: Optional[bool] = False#
replace(**kwargs)#
skip_apply_chat_template: bool = False#
sync_ref_model: bool = False#
to_dict()#

Convert the instance to a dictionary for JSON serialization.

to_json(**kwargs)#

Convert the instance to a JSON string.
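
The four serialization helpers are meant to round-trip a configuration. The sketch below shows that pattern on a stand-in dataclass; `ToyGRPOArgs` is hypothetical and carries only three of the fields listed for `GRPOConfig`, and forwarding `**kwargs` to `json.dumps` is an assumption about what `to_json(**kwargs)` does.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class ToyGRPOArgs:
    """Hypothetical stand-in with a few GRPOConfig fields and their defaults."""

    beta: float = 0.04
    max_prompt_length: int = 512
    max_completion_length: int = 256

    def to_dict(self):
        # Plain dict of field names to values, ready for JSON encoding.
        return asdict(self)

    @classmethod
    def from_dict(cls, data):
        return cls(**data)

    def to_json(self, **kwargs):
        # Extra keyword arguments (e.g. indent) go straight to json.dumps.
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str):
        return cls.from_dict(json.loads(json_str))


args = ToyGRPOArgs(beta=0.1)
restored = ToyGRPOArgs.from_json(args.to_json(indent=2))
print(restored == args)
```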

tools: Optional[List[Union[dict, Callable]]] = None#
class easydel.__init__.GRPOTrainer(arguments: GRPOConfig, vinference: vInference, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]], reward_funcs: Union[EasyDeLBaseModule, EasyDeLState, Callable[[list, list], list[float]], list[Union[easydel.infra.base_module.EasyDeLBaseModule, easydel.infra.base_state.EasyDeLState, Callable[[list, list], list[float]]]]], train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, processing_class: Optional[Any] = None, reward_processing_classes: Optional[Any] = None, data_tokenize_fn: Optional[Callable] = None)[source]#

Bases: Trainer

arguments: GRPOConfig#
configure_functions() TrainerConfigureFunctionOutput[source]#

Configures and JIT-compiles the training and evaluation step functions.

This method sets up the necessary functions for training and evaluation, including:
  • Initialization of the model state.

  • Sharding of the model parameters and optimizer state.

  • JIT-compilation of the training and evaluation step functions.

Returns

An object containing the configured functions and other relevant information.

Return type

TrainerConfigureFunctionOutput

on_step_end(state: EasyDeLState, metrics: Any, step: int) Tuple[EasyDeLState, Any][source]#

Hook called at the end of each step.

class easydel.__init__.Gemma2Config(vocab_size=256000, hidden_size=3072, intermediate_size=24576, num_hidden_layers=28, num_attention_heads=16, num_key_value_heads=16, head_dim=256, hidden_activation='gelu_pytorch_tanh', max_position_embeddings=8192, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, pad_token_id=0, eos_token_id=1, bos_token_id=2, tie_word_embeddings=True, rope_theta=10000.0, attention_bias=False, attention_dropout=0.0, final_logit_softcapping=30.0, query_pre_attn_scalar=224, sliding_window=4096, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, scan_layers: bool = False, attn_logit_softcapping: Optional[bool] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 256000) – Vocabulary size of the Gemma2 model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 3072) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 24576) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 28) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 16) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 16) – Number of key and value heads for each attention layer in the Transformer encoder.

  • head_dim (int, optional, defaults to 256) – Dimensionality of the attention head.

  • hidden_activation (str or function, optional, defaults to “gelu_pytorch_tanh”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 8192) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional, defaults to 0) – The index of the padding token in the vocabulary.

  • eos_token_id (int, optional, defaults to 1) – The index of the end of sequence token in the vocabulary.

  • bos_token_id (int, optional, defaults to 2) – The index of the beginning of sequence token in the vocabulary.

  • tie_word_embeddings (bool, optional, defaults to True) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • attention_bias (bool, optional, defaults to False) – Whether to use attention bias.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • final_logit_softcapping (float, optional, defaults to 30.0) – The soft capping value for the final logits.

  • query_pre_attn_scalar (int, optional, defaults to 224) – The scalar value for the query pre-attention layer.

  • sliding_window (int, optional, defaults to 4096) – The sliding window size.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • bits (int, optional) – The number of bits to quantize the model to.

  • scan_layers (bool, optional, defaults to False) – Whether to use the scan implementation of the layers.
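
The effect of `final_logit_softcapping` can be sketched with the usual tanh soft-capping formula, `cap * tanh(logit / cap)`. This assumes the tanh form used by the Gemma 2 reference implementation; the function name is hypothetical, and the sketch works on a single float rather than an array.

```python
import math


def softcap(logit, cap=30.0):
    """Squash a logit smoothly so that its magnitude never exceeds cap."""
    return cap * math.tanh(logit / cap)


for value in (0.0, 1.0, 30.0, 1000.0):
    print(value, softcap(value))
```

Small logits pass through almost unchanged, while large ones saturate near the cap, which bounds the logits without the hard cutoff that clipping would introduce.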

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

Attaches the following EasyDeL-specific arguments to the configuration object.

Parameters
  • gradient_checkpointing (EasyDeLGradientCheckPointers) – Gradient checkpointing strategy; controls the amount of memory used by JAX.

  • bits (tp.Optional[int]) – The number of bits used in the quantization.

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()[source]#

Returns a tuple of parameter names for which weight decay should be excluded.

Returns

An empty tuple, indicating no weight decay exclusions.

Return type

tuple

property granted_freq_max_position_embedding: int#

Returns the maximum position embedding size for frequency-based position embeddings.

Returns

The maximum position embedding size, falling back to max_position_embeddings if not explicitly set.

Return type

int

property granted_mask_max_position_embedding: int#

Returns the maximum position embedding size for mask-based position embeddings.

Returns

The maximum position embedding size, falling back to max_position_embeddings if not explicitly set.

Return type

int
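
The fallback described for the two `granted_*` properties can be sketched as below. The class and the override attribute names (`freq_max_position_embeddings`, `mask_max_position_embeddings`) are assumptions chosen for illustration; only the fall-back-to-`max_position_embeddings` behaviour comes from the descriptions above.

```python
class ToyPositionConfig:
    """Hypothetical config with optional overrides for position sizes."""

    def __init__(
        self,
        max_position_embeddings,
        freq_max_position_embeddings=None,
        mask_max_position_embeddings=None,
    ):
        self.max_position_embeddings = max_position_embeddings
        self.freq_max_position_embeddings = freq_max_position_embeddings
        self.mask_max_position_embeddings = mask_max_position_embeddings

    @property
    def granted_freq_max_position_embedding(self):
        # Use the explicit override when present, otherwise fall back.
        override = self.freq_max_position_embeddings
        return self.max_position_embeddings if override is None else override

    @property
    def granted_mask_max_position_embedding(self):
        override = self.mask_max_position_embeddings
        return self.max_position_embeddings if override is None else override


default_cfg = ToyPositionConfig(8192)
custom_cfg = ToyPositionConfig(8192, freq_max_position_embeddings=4096)
print(default_cfg.granted_freq_max_position_embedding)
print(custom_cfg.granted_freq_max_position_embedding)
```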

model_type: str = 'gemma2'#
static rng_keys()[source]#

Returns the names of the random number generator keys used by the model.

Returns

A tuple containing “params”, “dropout”, and “fcm” as the RNG keys.

Return type

tuple

class easydel.__init__.Gemma2ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Gemma2 model with a language modeling head for causal language modeling tasks.

This model extends the base Gemma2Model by incorporating a linear language modeling head on top of the base model, designed for generative tasks and text generation.

class easydel.__init__.Gemma2ForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.Gemma2Model(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.Gemma3Config(text_config: Optional[Gemma3TextConfig] = None, vision_config: Optional[SiglipVisionConfig] = None, mm_tokens_per_image: int = 256, boi_token_index: int = 255999, eoi_token_index: int = 256000, image_token_index: int = 262144, initializer_range: float = 0.02, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Parameters
  • text_config (Union[Gemma3TextConfig, dict], optional) – The config object of the text backbone.

  • vision_config (Union[AutoConfig, dict], optional) – Custom vision config or dict.

  • mm_tokens_per_image (int, optional, defaults to 256) – The number of tokens per image embedding.

  • boi_token_index (int, optional, defaults to 255999) – The begin-of-image token index to wrap the image prompt.

  • eoi_token_index (int, optional, defaults to 256000) – The end-of-image token index to wrap the image prompt.

  • image_token_index (int, optional, defaults to 262144) – The image token index to encode the image prompt.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

Example:

```python
>>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig

>>> # Initializing a Siglip-like vision config
>>> vision_config = SiglipVisionConfig()
>>> # Initializing a Gemma3 Text config
>>> text_config = Gemma3TextConfig()
>>> # Initializing a Gemma3 gemma-3-4b style configuration
>>> configuration = Gemma3Config(text_config=text_config, vision_config=vision_config)
>>> # Initializing a model from the gemma-3-4b style configuration
>>> model = Gemma3ForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

A tuple of tuples, where each inner tuple contains a regex pattern matching parameter names and the corresponding PartitionSpec for sharding those parameters across devices.

Return type

Tuple[Tuple[str, PartitionSpec]]

model_type: str = 'gemma3'#
sub_configs: dict[str, 'PretrainedConfig'] = {'text_config': <class 'easydel.__init__.modules.gemma3.gemma3_configuration.Gemma3TextConfig'>, 'vision_config': <class 'easydel.__init__.modules.siglip.configuration_siglip.SiglipVisionConfig'>}#
class easydel.__init__.Gemma3ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Gemma3 model with a language modeling head for causal language modeling tasks.

This model extends the base Gemma3TextModel by incorporating a linear language modeling head on top of the base model, designed for generative tasks and text generation. The model can optionally apply softcapping to logits based on configuration settings.

class easydel.__init__.Gemma3ForConditionalGeneration(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

get_image_features(pixel_values: Union[Array, ndarray, bool, number]) Union[Array, ndarray, bool, number][source]#
loss_type = 'ForCausalLM'#
prepare_inputs_for_generation(input_ids: Union[Array, ndarray, bool, number], max_length: int, pixel_values: Optional[Union[Array, ndarray, bool, number]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None, token_type_ids: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Prepares the inputs for a generation task.

Parameters
  • input_ids – The input tokens.

  • max_length – The length of the sequence to be generated.

  • pixel_values (optional) – Image pixel values, if any.

  • attention_mask (tp.Optional[chex.Array]) – Mask for the attention weights.

  • token_type_ids (tp.Optional[chex.Array]) – Token type IDs.

Returns

A dictionary of the past_key_values, attention_mask and position ids
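
A rough picture of the mask and position-id bookkeeping is sketched below using plain lists. It is an assumption-laden illustration, not the method's implementation: the helper name is hypothetical, the key/value cache is omitted entirely, and the "pad the mask with ones up to `max_length`" and "position = count of real tokens so far" conventions are borrowed from common Flax generation code.

```python
def prepare_inputs(attention_mask, max_length):
    """attention_mask: one list of 0/1 ints per sequence in the batch."""
    extended_mask, position_ids = [], []
    for row in attention_mask:
        # Extend the mask with ones so that slots reserved for tokens
        # generated later are attendable once they are filled.
        extended_mask.append(row + [1] * (max_length - len(row)))
        # Each prompt token's position is the number of real (unmasked)
        # tokens seen up to and including it, minus one, floored at zero.
        seen, ids = 0, []
        for m in row:
            seen += m
            ids.append(max(seen - 1, 0))
        position_ids.append(ids)
    return {"attention_mask": extended_mask, "position_ids": position_ids}


inputs = prepare_inputs([[0, 1, 1]], max_length=5)
print(inputs)
```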

update_inputs_for_generation(model_outputs, model_kwargs)[source]#
class easydel.__init__.Gemma3ForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.Gemma3MultiModalProjector(*args: Any, **kwargs: Any)[source]#

Bases: Module

class easydel.__init__.Gemma3TextConfig(vocab_size=262208, hidden_size=2304, intermediate_size=9216, num_hidden_layers=26, num_attention_heads=8, num_key_value_heads=4, head_dim=256, hidden_activation='gelu_pytorch_tanh', max_position_embeddings=131072, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, pad_token_id=0, eos_token_id=1, bos_token_id=2, tie_word_embeddings=True, rope_theta=1000000.0, attention_bias=False, attention_dropout=0.0, query_pre_attn_scalar=256, sliding_window=4096, final_logit_softcapping=None, attn_logit_softcapping=None, cache_implementation='hybrid', rope_scaling=None, rope_local_base_freq=10000.0, sliding_window_pattern=6, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, scan_layers: bool = False, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 262208) – Vocabulary size of the Gemma3Text model. Defines the number of different tokens that can be represented by the inputs_ids passed when calling [Gemma3TextModel].

  • hidden_size (int, optional, defaults to 2304) – Dimension of the hidden representations.

  • intermediate_size (int, optional, defaults to 9216) – Dimension of the MLP representations.

  • num_hidden_layers (int, optional, defaults to 26) – Number of hidden layers in the Transformer decoder.

  • num_attention_heads (int, optional, defaults to 8) – Number of attention heads for each attention layer in the Transformer decoder.

  • num_key_value_heads (int, optional, defaults to 4) – The number of key_value heads used to implement Grouped Query Attention. If num_key_value_heads=num_attention_heads, the model will use Multi Head Attention (MHA); if num_key_value_heads=1, the model will use Multi Query Attention (MQA); otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by mean-pooling all the original heads within that group. For more details, check out [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to num_attention_heads.

  • head_dim (int, optional, defaults to 256) – The attention head dimension.

  • hidden_activation (str or function, optional, defaults to “gelu_pytorch_tanh”) – The non-linear activation function (function or string) in the decoder. Will default to “gelu_pytorch_tanh” if not specified. “gelu_pytorch_tanh” uses an approximation of the “gelu” activation function.

  • max_position_embeddings (int, optional, defaults to 131072) – The maximum sequence length that this model might ever be used with.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-06) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional, defaults to 0) – Padding token id.

  • eos_token_id (int, optional, defaults to 1) – End of stream token id.

  • bos_token_id (int, optional, defaults to 2) – Beginning of stream token id.

  • tie_word_embeddings (bool, optional, defaults to True) – Whether to tie weight embeddings

  • rope_theta (float, optional, defaults to 1000000.0) – The base period of the RoPE embeddings.

  • attention_bias (bool, optional, defaults to False) – Whether to use a bias in the query, key, value and output projection layers during self-attention.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • query_pre_attn_scalar (float, optional, defaults to 256) – Scaling factor used on the attention scores.

  • sliding_window (int, optional, defaults to 4096) – In Gemma3Text, every other layer uses sliding window attention. This is the size of the sliding window.

  • final_logit_softcapping (float, optional) – Scaling factor when applying tanh softcapping on the logits.

  • attn_logit_softcapping (float, optional) – Scaling factor when applying tanh softcapping on the attention scores.

  • cache_implementation (str, optional, defaults to “hybrid”) – the cache type to be used with generate.

  • rope_scaling (Dict, optional) –

    Dictionary containing the scaling configuration for the RoPE embeddings used in global attention. NOTE: if you apply a new rope type and you expect the model to work on longer max_position_embeddings, we recommend you update this value accordingly. Expected contents:

    rope_type (str):

    The sub-variant of RoPE to use. Can be one of [‘default’, ‘linear’, ‘dynamic’, ‘yarn’, ‘longrope’, ‘llama3’], with ‘default’ being the original RoPE implementation.

    factor (float, optional):

    Used with all rope types except ‘default’. The scaling factor to apply to the RoPE embeddings. In most scaling types, a factor of x will enable the model to handle sequences of length x * original maximum pre-trained length.

    original_max_position_embeddings (int, optional):

    Used with ‘dynamic’, ‘longrope’ and ‘llama3’. The original max position embeddings used during pretraining.

    attention_factor (float, optional):

    Used with ‘yarn’ and ‘longrope’. The scaling factor to be applied on the attention computation. If unspecified, it defaults to value recommended by the implementation, using the factor field to infer the suggested value.

    beta_fast (float, optional):

    Only used with ‘yarn’. Parameter to set the boundary for extrapolation (only) in the linear ramp function. If unspecified, it defaults to 32.

    beta_slow (float, optional):

    Only used with ‘yarn’. Parameter to set the boundary for interpolation (only) in the linear ramp function. If unspecified, it defaults to 1.

    short_factor (List[float], optional):

    Only used with ‘longrope’. The scaling factor to be applied to short contexts (< original_max_position_embeddings). Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2

    long_factor (List[float], optional):

    Only used with ‘longrope’. The scaling factor to be applied to long contexts (> original_max_position_embeddings). Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2

    low_freq_factor (float, optional):

    Only used with ‘llama3’. Scaling factor applied to low frequency components of the RoPE

    high_freq_factor (float, optional):

    Only used with ‘llama3’. Scaling factor applied to high frequency components of the RoPE

  • rope_local_base_freq (float, optional, defaults to 10000.0) – The base period of the RoPE embeddings for local attention.

  • sliding_window_pattern (int, optional, defaults to 6) – Pattern for the sliding window attention.
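
The MHA / MQA / GQA rule stated for `num_key_value_heads` can be written out as a small helper. This is an illustrative sketch; the function name is hypothetical and the divisibility check is an added assumption about what a valid configuration looks like.

```python
def attention_variant(num_attention_heads, num_key_value_heads=None):
    """Return (variant name, query heads sharing each key/value head)."""
    if num_key_value_heads is None:
        # Unspecified: fall back to one key/value head per query head.
        num_key_value_heads = num_attention_heads
    if num_attention_heads % num_key_value_heads != 0:
        raise ValueError("query heads must divide evenly into key/value heads")
    group_size = num_attention_heads // num_key_value_heads
    if num_key_value_heads == num_attention_heads:
        return "MHA", group_size
    if num_key_value_heads == 1:
        return "MQA", group_size
    return "GQA", group_size


print(attention_variant(8, 4))
print(attention_variant(8))
print(attention_variant(8, 1))
```

With the Gemma3Text defaults (8 query heads, 4 key/value heads) this reports GQA with two query heads sharing each key/value head.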

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

Attaches the following EasyDeL-specific arguments to the configuration object.

Parameters
  • gradient_checkpointing (EasyDeLGradientCheckPointers) – Gradient checkpointing strategy; controls the amount of memory used by JAX.

  • bits (tp.Optional[int]) – The number of bits used in the quantization.

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()[source]#
property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
model_type: str = 'gemma3_text'#
static rng_keys()[source]#
class easydel.__init__.Gemma3TextModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

property default_frequencies#
class easydel.__init__.GemmaConfig(vocab_size=256000, hidden_size=3072, intermediate_size=24576, num_hidden_layers=28, num_attention_heads=16, num_key_value_heads=16, head_dim=256, hidden_act='gelu_pytorch_tanh', max_position_embeddings=8192, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, pad_token_id=0, eos_token_id=1, bos_token_id=2, tie_word_embeddings=True, rope_theta=10000.0, attention_bias=False, attention_dropout=0.0, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, scan_layers: bool = False, hidden_activation='gelu_pytorch_tanh', **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 256000) – Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 3072) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 24576) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 28) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 16) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 16) – Number of key and value heads for each attention layer in the Transformer encoder.

  • head_dim (int, optional, defaults to 256) – Dimensionality of the attention head.

  • hidden_act (str or function, optional, defaults to “gelu_pytorch_tanh”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 8192) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional, defaults to 0) – The index of the padding token in the vocabulary.

  • eos_token_id (int, optional, defaults to 1) – The index of the end of sequence token in the vocabulary.

  • bos_token_id (int, optional, defaults to 2) – The index of the beginning of sequence token in the vocabulary.

  • tie_word_embeddings (bool, optional, defaults to True) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • attention_bias (bool, optional, defaults to False) – Whether to use attention bias.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • bits (int, optional) – The number of bits to quantize the model to.

  • scan_layers (bool, optional, defaults to False) – Whether to use the scan implementation of the layers.

  • hidden_activation (str, optional) – The hidden activation function to use.
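Note that with the defaults above, num_attention_heads × head_dim (16 × 256 = 4096) is not equal to hidden_size (3072), because head_dim is set independently. The sketch below works through the resulting projection shapes. It assumes the usual decoder attention layout with separate query, key, value, and output projections and an (input, output) shape convention; the function name and the dictionary keys are illustrative, not EasyDeL API.

```python
def attention_projection_shapes(
    hidden_size: int = 3072,
    num_attention_heads: int = 16,
    num_key_value_heads: int = 16,
    head_dim: int = 256,
):
    """Return (in_features, out_features) for each attention projection."""
    q_out = num_attention_heads * head_dim   # width of all query heads
    kv_out = num_key_value_heads * head_dim  # width of all key (or value) heads
    return {
        "q_proj": (hidden_size, q_out),
        "k_proj": (hidden_size, kv_out),
        "v_proj": (hidden_size, kv_out),
        "o_proj": (q_out, hidden_size),      # maps heads back to the model width
    }


shapes = attention_projection_shapes()
```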

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

Attaches the following custom arguments to the configuration object:

Parameters
  • self – The current configuration object.

  • gradient_checkpointing (EasyDeLGradientCheckPointers) – Gradient checkpointing strategy; controls the amount of memory used by JAX.

  • bits (tp.Optional[int]) – Number of bits used for quantization.

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()[source]#

Returns a tuple of parameter names for which weight decay should be excluded.

Returns

An empty tuple, indicating no weight decay exclusions.

Return type

tuple
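As a rough illustration of how a tuple of exclusion names might be consumed, the sketch below builds a per-parameter weight decay mask. Everything in it is an assumption: the helper name, the parameter paths, and the choice to treat each exclusion as a regular expression searched within the parameter path. With an empty tuple, as returned here, every parameter receives weight decay.

```python
import re


def weight_decay_mask(param_names, exclusions):
    """Map each parameter name to True if weight decay should apply to it."""
    patterns = [re.compile(p) for p in exclusions]
    return {
        name: not any(p.search(name) for p in patterns)
        for name in param_names
    }


# Hypothetical parameter paths, for illustration only.
names = ["model/layers/0/self_attn/q_proj/kernel", "model/norm/kernel"]

mask_all = weight_decay_mask(names, ())            # empty tuple: decay everything
mask_no_norm = weight_decay_mask(names, (r"norm",))  # exclude normalization weights
```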

property granted_freq_max_position_embedding: int#

Returns the maximum position embedding size for frequency-based position embeddings.

Returns

The maximum position embedding size, falling back to max_position_embeddings if not explicitly set.

Return type

int

property granted_mask_max_position_embedding: int#

Returns the maximum position embedding size for mask-based position embeddings.

Returns

The maximum position embedding size, falling back to max_position_embeddings if not explicitly set.

Return type

int

model_type: str = 'gemma'#
static rng_keys()[source]#

Returns the names of the random number generator keys used by the model.

Returns

A tuple containing “params”, “dropout”, and “fcm” as the RNG keys.

Return type

tuple

class easydel.__init__.GemmaForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.GemmaForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.GemmaModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.Grok1Config(vocab_size=32000, hidden_size=4096, intermediate_size=32768, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=32, attn_output_multiplier=1.0, max_attn_value=1.0, max_position_embeddings=4096, embedding_multiplier_scale: float = 1.0, output_multiplier_scale: float = 1.0, rms_norm_eps=1e-05, use_cache=True, pad_token_id=None, bos_token_id=1, eos_token_id=2, tie_word_embeddings=True, num_experts_per_tok=2, num_experts=8, output_router_logits=False, router_aux_loss_coef=0.001, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 32000) – Vocabulary size of the Grok-1 model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 32768) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 32) – Number of key and value heads for each attention layer in the Transformer encoder.

  • attn_output_multiplier (float, optional, defaults to 1.0) – The multiplier value applied to the attention output.

  • max_attn_value (float, optional, defaults to 1.0) – The maximum value of the attention weights.

  • max_position_embeddings (int, optional, defaults to 4096) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • embedding_multiplier_scale (float, optional, defaults to 1.0) – The scale factor for the embedding layer.

  • output_multiplier_scale (float, optional, defaults to 1.0) – The scale factor for the output layer.

  • rms_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional) – The index of the padding token in the vocabulary.

  • bos_token_id (int, optional, defaults to 1) – The index of the beginning of sequence token in the vocabulary.

  • eos_token_id (int, optional, defaults to 2) – The index of the end of sequence token in the vocabulary.

  • tie_word_embeddings (bool, optional, defaults to True) – Whether to tie the weights of the input embeddings and the output embeddings.

  • num_experts_per_tok (int, optional, defaults to 2) – The number of experts per token.

  • num_experts (int, optional, defaults to 8) – The number of experts.

  • output_router_logits (bool, optional, defaults to False) – Whether to output router logits.

  • router_aux_loss_coef (float, optional, defaults to 0.001) – The router auxiliary loss coefficient.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • bits (int, optional) – The number of bits to quantize the model to.
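To make num_experts and num_experts_per_tok concrete, here is a minimal sketch of generic top-k expert routing for a single token. It is not Grok-1's or EasyDeL's router: the function name is hypothetical, and it assumes a softmax over all experts followed by renormalization of the selected weights, which is only one common variant.

```python
import math


def top_k_route(router_logits, num_experts_per_tok=2):
    """Pick the top-k experts for one token and renormalize their weights."""
    # Numerically stable softmax over all experts.
    m = max(router_logits)
    exps = [math.exp(x - m) for x in router_logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Indices of the k most probable experts, highest first.
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    top = top[:num_experts_per_tok]
    # Renormalize so the selected weights sum to 1.
    norm = sum(probs[i] for i in top)
    return [(i, probs[i] / norm) for i in top]


# One token, num_experts=8 router logits (made-up values).
logits = [0.1, 2.0, -1.0, 0.5, 3.0, 0.0, -0.5, 1.0]
routed = top_k_route(logits, num_experts_per_tok=2)
```

Each returned pair is (expert index, mixing weight); the token's output would be the weighted sum of the selected experts' outputs.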

attach_custom_arguments(tie_word_embeddings: bool = False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

Attaches custom arguments to the configuration object.

This method allows adding or overriding configuration attributes dynamically. It primarily sets attributes related to word embeddings, gradient checkpointing, and quantization bits.

Parameters
  • tie_word_embeddings (bool, optional) – Whether to tie input/output embeddings. Defaults to False.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing strategy. Defaults to EasyDeLGradientCheckPointers.NONE.

  • bits (tp.Optional[int], optional) – Quantization bits. Defaults to None.

  • **kwargs – Additional keyword arguments (ignored).

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model. This method defines how the model’s parameters are partitioned across devices for distributed training and inference.

Parameters
  • *args – Additional positional arguments (unused).

  • **kwargs – Additional keyword arguments (unused).

Returns

A tuple of partition rules, where each rule is a tuple containing a regex pattern for parameter names and the corresponding PartitionSpec.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]
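The (regex pattern, PartitionSpec) structure described above can be illustrated with a small matcher. This is a sketch under stated assumptions: the rules, mesh-axis names, and parameter paths are invented for illustration, plain tuples stand in for jax.sharding.PartitionSpec to keep the example dependency-free, and first-match-wins ordering is assumed rather than taken from EasyDeL.

```python
import re

# Hypothetical rules: (regex over the parameter path, sharding spec).
# Plain tuples of mesh-axis names stand in for jax.sharding.PartitionSpec.
RULES = (
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"self_attn/(q_proj|k_proj|v_proj)/kernel", ("fsdp", "tp")),
    (r"self_attn/o_proj/kernel", ("tp", "fsdp")),
    (r".*", (None,)),  # catch-all: replicate anything not matched above
)


def match_partition_rule(rules, param_path):
    """Return the spec of the first rule whose pattern matches param_path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


q_spec = match_partition_rule(RULES, "model/layers/0/self_attn/q_proj/kernel")
norm_spec = match_partition_rule(RULES, "model/norm/kernel")
```

Because the first match wins in this sketch, specific patterns must precede the catch-all rule.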

static get_weight_decay_exclusions()[source]#

Returns a tuple of parameter names for which weight decay should be excluded.

Returns

An empty tuple, indicating no specific weight decay exclusions for this model.

Return type

tuple

property granted_freq_max_position_embedding: int#

Returns the maximum position embedding size specifically for frequency-based position embeddings.

If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.

Returns

The granted maximum position embedding size for frequency encoding.

Return type

int

property granted_mask_max_position_embedding: int#

Returns the maximum position embedding size specifically for mask-based position embeddings.

If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.

Returns

The granted maximum position embedding size for mask encoding.

Return type

int
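The fallback behavior of the two granted_* properties can be sketched as follows. The class below is a hypothetical stand-in, not the EasyDeL config, and it assumes that "set" means "not None".

```python
from typing import Optional


class PositionLimits:
    """Stand-in holding only the three fields involved in the fallback."""

    def __init__(
        self,
        max_position_embeddings: int,
        freq_max_position_embeddings: Optional[int] = None,
        mask_max_position_embeddings: Optional[int] = None,
    ):
        self.max_position_embeddings = max_position_embeddings
        self.freq_max_position_embeddings = freq_max_position_embeddings
        self.mask_max_position_embeddings = mask_max_position_embeddings

    @property
    def granted_freq_max_position_embedding(self) -> int:
        # Prefer the frequency-specific limit when it is set.
        if self.freq_max_position_embeddings is not None:
            return self.freq_max_position_embeddings
        return self.max_position_embeddings

    @property
    def granted_mask_max_position_embedding(self) -> int:
        # Prefer the mask-specific limit when it is set.
        if self.mask_max_position_embeddings is not None:
            return self.mask_max_position_embeddings
        return self.max_position_embeddings


limits = PositionLimits(max_position_embeddings=4096, freq_max_position_embeddings=8192)
```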

model_type: str = 'grok-1'#
static rng_keys()[source]#

Returns the names of the random number generator keys used by the model.

Returns

A tuple containing “params” and “dropout” as the RNG keys.

Return type

tuple

class easydel.__init__.Grok1ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Grok-1 model with a language modeling head.

This model extends the base Grok1Model by adding a linear layer on top to predict the next token in a sequence, making it suitable for causal language modeling tasks. It also includes handling for the Mixture of Experts auxiliary loss.

config#

Configuration object for the model.

Type

Grok1Config

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.__init__.Grok1Model(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Grok-1 model implementation.

This class implements the main Grok-1 transformer model architecture, consisting of an embedding layer, multiple Grok1DecoderLayer layers (with sparse MoE), and a final RMS normalization layer.

config#

Configuration object for the model.

Type

Grok1Config

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

property frequencies#

Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_frequencies() and caches the result.

Returns

The frequency components, potentially cached.

Return type

jnp.ndarray

class easydel.__init__.InternLM2Config(vocab_size=103168, hidden_size=4096, intermediate_size=11008, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=None, hidden_act='silu', max_position_embeddings=2048, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, pad_token_id=0, bos_token_id=1, eos_token_id=2, pretraining_tp=1, tie_word_embeddings=False, bias=True, rope_theta=10000, rope_scaling=None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, fcm_min_ratio: float = -1, fcm_max_ratio: float = -1, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, scan_layers: bool = False, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 103168) – Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 11008) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer encoder. Will default to number_rep_kv * num_attention_heads if not set.

  • max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional, defaults to 0) – The id of the pad token.

  • bos_token_id (int, optional, defaults to 1) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • bias (bool, optional, defaults to True) – Whether to use attention bias.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • fcm_min_ratio (float, optional, defaults to -1) – The minimum ratio for forgetful causal masking (FCM).

  • fcm_max_ratio (float, optional, defaults to -1) – The maximum ratio for forgetful causal masking (FCM).

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.

  • scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size to use when scanning the MLP.

  • bits (int, optional) – The number of bits to quantize the model to.

  • hidden_act (str, optional, defaults to “silu”) – The hidden activation function to use.

  • pretraining_tp (int, optional, defaults to 1) – The tensor parallelism degree used during pretraining.

  • mlp_bias (bool, optional, defaults to False) – Whether to use bias in the MLP.

  • scan_layers (bool, optional, defaults to False) – Whether to use the scan implementation for the layers.

attach_custom_arguments(tie_word_embeddings: bool = False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, fcm_min_ratio: float = 0.0, fcm_max_ratio: float = 0.0, bits: Optional[int] = None, rope_theta: float = 10000.0, hidden_act: str = 'silu', scan_layers: bool = True, **kwargs)[source]#

Attaches custom arguments to the configuration object.

This method allows adding or overriding configuration attributes dynamically. It iterates through the provided arguments and sets them as attributes of the configuration object.

Parameters
  • tie_word_embeddings (bool, optional) – Whether to tie input/output embeddings. Defaults to False.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing strategy. Defaults to EasyDeLGradientCheckPointers.NONE.

  • fcm_min_ratio (float, optional) – Minimum ratio for forgetful causal masking (FCM). Defaults to 0.0.

  • fcm_max_ratio (float, optional) – Maximum ratio for forgetful causal masking (FCM). Defaults to 0.0.

  • bits (tp.Optional[int], optional) – Quantization bits. Defaults to None.

  • rope_theta (float, optional) – Base value for RoPE. Defaults to 10000.0.

  • hidden_act (str, optional) – Activation function. Defaults to “silu”.

  • scan_layers (bool, optional) – Whether to use scan layers. Defaults to True.

  • **kwargs – Additional keyword arguments (ignored in this implementation).

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model. This method defines how the model’s parameters are partitioned across devices for distributed training and inference.

Parameters
  • *args – Additional positional arguments (unused).

  • **kwargs – Additional keyword arguments (unused).

Returns

A tuple of partition rules, where each rule is a tuple containing a regex pattern for parameter names and the corresponding PartitionSpec.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()[source]#

Returns a tuple of parameter names for which weight decay should be excluded.

Returns

An empty tuple, indicating no specific weight decay exclusions for this model.

Return type

tuple

property granted_freq_max_position_embedding: int#

Returns the maximum position embedding size specifically for frequency-based position embeddings.

If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.

Returns

The granted maximum position embedding size for frequency encoding.

Return type

int

property granted_mask_max_position_embedding: int#

Returns the maximum position embedding size specifically for mask-based position embeddings.

If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.

Returns

The granted maximum position embedding size for mask encoding.

Return type

int

model_type: str = 'internlm2'#
static rng_keys()[source]#

Returns the names of the random number generator keys used by the model.

Returns

A tuple containing “params”, “dropout”, and “fcm” as the RNG keys.

Return type

tuple

class easydel.__init__.InternLM2ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

InternLM2 model with a Causal Language Modeling head.

This model consists of the base InternLM2 transformer (InternLM2Model) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction.

config#

Configuration object for the model.

Type

InternLM2Config

dtype#

Data type for computation. Default is jnp.float32.

Type

jnp.dtype

param_dtype#

Data type for parameters. Default is jnp.float32.

Type

jnp.dtype

precision#

Precision setting for JAX operations. Default is None.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

module#

The core InternLM2 transformer model.

Type

InternLM2Model

lm_head#

The linear layer for projecting hidden states to vocabulary logits.

Type

ParallelLinear

class easydel.__init__.InternLM2ForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

InternLM2 model with a Sequence Classification head.

This model consists of the base InternLM2 transformer (InternLM2Model) followed by a linear layer (score) that projects the transformer’s output hidden states (typically the hidden state of the first token) to the number of classes for classification.

config#

Configuration object for the model.

Type

InternLM2Config

dtype#

Data type for computation. Default is jnp.float32.

Type

jnp.dtype

param_dtype#

Data type for parameters. Default is jnp.float32.

Type

jnp.dtype

precision#

Precision setting for JAX operations. Default is None.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

module#

The core InternLM2 transformer model.

Type

InternLM2Model

score#

The linear layer for classification.

Type

ParallelLinear

class easydel.__init__.InternLM2Model(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

The base InternLM2 model transformer.

This class represents the core transformer architecture of the InternLM2 model, consisting of embedding layers, multiple transformer blocks, and a final layer normalization.

config#

Configuration object for the model.

Type

InternLM2Config

dtype#

Data type for computation. Default is jnp.float32.

Type

jnp.dtype

param_dtype#

Data type for parameters. Default is jnp.float32.

Type

jnp.dtype

precision#

Precision setting for JAX operations. Default is None.

Type

jax.lax.PrecisionLike

embed_tokens#

Embedding layer for input tokens.

Type

nn.Embed

layers#

Sequence of transformer blocks.

Type

tp.Sequence[InternLM2Block]

norm#

Final layer normalization.

Type

RMSNorm

gradient_checkpointing#

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers

scan_layers#

Whether to use JAX scan for layer processing.

Type

bool

blocks_class#

The class used for the transformer blocks.

Type

InternLM2Block

class easydel.__init__.JaxDistributedConfig[source]#

Bases: object

Utility class for initializing JAX distributed (adapted from EasyLM).

static get_default_config(updates=None)[source]#
classmethod initialize(config=None)[source]#
class easydel.__init__.Llama4Config(vision_config=None, text_config=None, boi_token_index=200080, eoi_token_index=200081, image_token_index=200092, tie_word_embeddings=False, **kwargs)[source]#

Bases: EasyDeLBaseConfig

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the Llama4 model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'llama4'#
sub_configs: dict[str, 'PretrainedConfig'] = {'text_config': <class 'easydel.__init__.modules.llama4.llama4_configuration.Llama4TextConfig'>, 'vision_config': <class 'easydel.__init__.modules.llama4.llama4_configuration.Llama4VisionConfig'>}#
class easydel.__init__.Llama4ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.Llama4ForConditionalGeneration(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Llama4Vision model for conditional text generation based on image inputs. Combines a vision tower and a language model with a multi-modal projector.

config#

Configuration object.

Type

Llama4VisionConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

JAX precision level.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

get_image_features(pixel_values: Union[Array, ndarray, bool, number], **kwargs) Union[Array, ndarray, bool, number][source]#

Extracts and projects image features from the vision tower.

Parameters

pixel_values (chex.Array) – Input pixel values for the images.

Returns

Processed image features ready for the language model.

Return type

chex.Array

loss_type = 'ForCausalLM'#
prepare_inputs_for_generation(input_ids: Union[Array, ndarray, bool, number], max_length: int, pixel_values: Optional[Union[Array, ndarray, bool, number]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Prepares inputs for text generation, including pixel values if provided.

Parameters
  • input_ids (chex.Array) – Initial input token IDs.

  • max_length (int) – Maximum generation length.

  • pixel_values (Optional[chex.Array]) – Pixel values for image input.

  • attention_mask (Optional[chex.Array]) – Attention mask.

Returns

Model inputs ready for generation.

Return type

dict

update_inputs_for_generation(model_outputs, model_kwargs)[source]#

Updates model inputs for the next step of generation, removing pixel values after the first step.

Parameters
  • model_outputs – Outputs from the previous generation step.

  • model_kwargs – Current keyword arguments for the model.

Returns

Updated model keyword arguments.

Return type

dict
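The pixel-value handling described for update_inputs_for_generation can be sketched in isolation. This is only an illustration of that one step, with a hypothetical helper name: the real method also receives model_outputs and may update other entries (for example cache state), which this sketch does not attempt to reproduce.

```python
def drop_pixel_values(model_kwargs: dict) -> dict:
    """Return a copy of model_kwargs without the 'pixel_values' entry."""
    updated = dict(model_kwargs)       # shallow copy; the caller's dict is untouched
    updated.pop("pixel_values", None)  # no error if the key is already absent
    return updated


# First step carries image data; later steps only need the text-side inputs.
first_step_kwargs = {"attention_mask": [1, 1, 1], "pixel_values": [[0.0, 0.5]]}
next_step_kwargs = drop_pixel_values(first_step_kwargs)
```

Dropping the images after the first step avoids re-running the vision tower for every generated token.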

class easydel.__init__.Llama4ForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Llama model for sequence classification tasks.

This class extends the base Llama model by adding a linear classification head to perform sequence classification tasks such as sentiment analysis or text classification.

config#

Configuration for the model.

Type

LlamaConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

class easydel.__init__.Llama4TextConfig(vocab_size=202048, hidden_size=5120, intermediate_size=8192, intermediate_size_mlp=16384, num_hidden_layers=48, num_attention_heads=40, num_key_value_heads=8, head_dim=128, hidden_act='silu', max_position_embeddings=131072, initializer_range=0.02, rms_norm_eps=1e-05, use_cache=True, pad_token_id=None, bos_token_id=1, eos_token_id=2, tie_word_embeddings=False, rope_theta=500000, attention_dropout=0.0, num_experts_per_tok=1, num_local_experts=16, moe_layers=None, interleave_moe_layer_step=1, use_qk_norm=True, output_router_logits=False, router_aux_loss_coef=0.001, router_jitter_noise=0.0, rope_scaling=None, no_rope_layers=None, no_rope_layer_interval=4, attention_chunk_size=8192, attn_temperature_tuning=4, floor_scale=8192, attn_scale=0.1, **kwargs)[source]#

Bases: EasyDeLBaseConfig

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the Llama4Text model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'llama4_text'#
class easydel.__init__.Llama4TextModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.Llama4VisionConfig(hidden_size: int = 768, hidden_act: str = 'gelu', num_hidden_layers: int = 34, num_attention_heads: int = 16, num_channels: int = 3, intermediate_size: int = 5632, vision_output_dim: int = 7680, image_size: int = 448, patch_size: int = 14, norm_eps: float = 1e-05, vision_feature_layer=-1, vision_feature_select_strategy='default', initializer_range: float = 0.02, pixel_shuffle_ratio=0.5, projector_input_dim=4096, projector_output_dim=4096, multi_modal_projector_bias=False, projector_dropout=0.0, attention_dropout=0.0, rope_theta=10000, **kwargs)[source]#

Bases: EasyDeLBaseConfig

base_config_key: str = 'vision_config'#
get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the Llama4Vision model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'llama4_vision_model'#
class easydel.__init__.Llama4VisionModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

property frequencies#

Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_frequencies() and caches the result.

Returns

The frequency components, potentially cached.

Return type

jnp.ndarray

class easydel.__init__.LlamaConfig(vocab_size: int = 32000, hidden_size: int = 4096, intermediate_size: int = 11008, num_hidden_layers: int = 32, num_attention_heads: int = 32, number_rep_kv: int = 1, head_dim: Optional[int] = None, num_key_value_heads: Optional[int] = None, max_position_embeddings: int = 2048, rms_norm_eps: float = 1e-06, initializer_range: float = 0.02, use_cache: bool = True, bos_token_id: int = 0, eos_token_id: int = 1, resid_pdrop: float = 0.0, embd_pdrop: float = 0.0, attention_dropout: float = 0.0, rope_theta: float = 10000.0, attention_bias: bool = False, tie_word_embeddings: bool = False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, fcm_min_ratio: float = -1, fcm_max_ratio: float = -1, rope_scaling: Dict[str, Union[str, float]] = None, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, hidden_act: str = 'silu', pretraining_tp: int = 1, mlp_bias: bool = False, scan_layers: bool = False, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 32000) – Vocabulary size of the Llama model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 11008) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.

  • number_rep_kv (int, optional, defaults to 1) – Number of repetitions for the key and value vectors.

  • num_key_value_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer encoder. Will default to number_rep_kv * num_attention_heads if not set.

  • max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • head_dim (int, optional) – The per-head dimension used for the attention query, key, and value projections.

  • rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • bos_token_id (int, optional, defaults to 0) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 1) – The id of the end-of-sequence token.

  • resid_pdrop (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

  • embd_pdrop (float, optional, defaults to 0.0) – The dropout ratio for the embeddings.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • attention_bias (bool, optional, defaults to False) – Whether to use attention bias.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • fcm_min_ratio (float, optional, defaults to -1) – The minimum ratio for Flash Attention.

  • fcm_max_ratio (float, optional, defaults to -1) – The maximum ratio for Flash Attention.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.

  • scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size to use when scanning the MLP.

  • bits (int, optional) – The number of bits to quantize the model to.

  • hidden_act (str, optional, defaults to “silu”) – The hidden activation function to use.

  • pretraining_tp (int, optional, defaults to 1) – The tensor parallelism degree used during pretraining.

  • mlp_bias (bool, optional, defaults to False) – Whether to use bias in the MLP.

  • scan_layers (bool, optional, defaults to False) – Whether to use the scan implementation for the layers.
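
The relationship between num_attention_heads, number_rep_kv, num_key_value_heads, and head_dim described above can be sketched as a small resolver. This is a minimal illustration, not EasyDeL's code: resolve_attention_shapes is a hypothetical helper, the num_key_value_heads fallback follows the wording of the parameter list, and the head_dim fallback (hidden_size // num_attention_heads) is an assumption borrowed from common Llama implementations.

```python
from typing import Optional, Tuple


def resolve_attention_shapes(
    hidden_size: int = 4096,
    num_attention_heads: int = 32,
    number_rep_kv: int = 1,
    num_key_value_heads: Optional[int] = None,
    head_dim: Optional[int] = None,
) -> Tuple[int, int]:
    """Hypothetical helper: return (num_key_value_heads, head_dim)."""
    if num_key_value_heads is None:
        # Fallback as worded in the parameter list above.
        num_key_value_heads = number_rep_kv * num_attention_heads
    if head_dim is None:
        # Assumption: the usual Llama convention when head_dim is not given.
        head_dim = hidden_size // num_attention_heads
    return num_key_value_heads, head_dim


# Documented defaults: every attention head has its own key/value head.
default_shapes = resolve_attention_shapes()
# Grouped-query attention: 32 query heads share 8 key/value heads.
gqa_shapes = resolve_attention_shapes(num_key_value_heads=8)
```

Passing num_key_value_heads explicitly is what selects grouped-query attention in this sketch; leaving it unset reproduces plain multi-head attention.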

attach_custom_arguments(resid_pdrop: float = 0.0, embd_pdrop: float = 0.0, attention_dropout: float = 0.0, tie_word_embeddings: bool = False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, fcm_min_ratio: float = 0.0, fcm_max_ratio: float = 0.0, number_rep_kv: int = 1, bits: Optional[int] = None, rope_theta: float = 10000.0, attention_bias: bool = False, hidden_act: str = 'silu', scan_layers: bool = True, **kwargs)[source]#

The attach_custom_arguments function adds the following arguments to the Transformer class:

Parameters
  • self – Refer to the current object

  • resid_pdrop – float: Set the dropout rate for residual connections

  • embd_pdrop – float: Set the probability of dropping an embedding

  • attention_dropout – float: Set the probability of dropping out the attention layer

  • tie_word_embeddings – bool: Tie the word embeddings to the decoder

  • gradient_checkpointing – str: Control the amount of memory used by jax

  • fcm_min_ratio – float: Control the minimum ratio of the number of chunks to be used in flash-based computation

  • fcm_max_ratio – float: Set the maximum ratio of the number of input tokens to output tokens

  • number_rep_kv – int: Determine how many times the key and value vectors are repeated

  • bits – tp.Optional[int]: Determine the number of bits used in the quantization

  • rope_theta – float: Base (theta) used to compute the rotary position embeddings

  • attention_bias – bool: Whether to use a bias in the attention projections

  • hidden_act – str: Activation function used in the MLP

  • scan_layers – bool: Determine whether to use scan layers or not
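
To make the role of rope_theta concrete, the sketch below computes the inverse frequencies that rotary position embeddings are usually built from. It is an illustration under stated assumptions, not EasyDeL's implementation: rope_inverse_frequencies is a hypothetical helper, and head_dim=128 is simply hidden_size / num_attention_heads for the documented defaults.

```python
from typing import List


def rope_inverse_frequencies(head_dim: int = 128, rope_theta: float = 10000.0) -> List[float]:
    """Hypothetical helper: one inverse frequency per pair of head dimensions."""
    # Standard RoPE formulation: inv_freq[k] = 1 / theta ** (2k / head_dim).
    return [1.0 / (rope_theta ** (i / head_dim)) for i in range(0, head_dim, 2)]


inv_freq = rope_inverse_frequencies()
```

A larger rope_theta stretches the lowest frequencies, which is why long-context variants typically raise it.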

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()[source]#
property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
model_type: str = 'llama'#
static rng_keys()[source]#
class easydel.__init__.LlamaForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Llama model with a language modeling head for causal language modeling tasks.

This model is a transformer-based language model with causal attention masks applied to perform autoregressive language generation.

config#

Configuration for the model.

Type

LlamaConfig

dtype#

Data type for computations (default is jnp.float32).

Type

jnp.dtype

param_dtype#

Data type for parameters (default is jnp.float32).

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

tp.Optional[tp.Union[str, jax.lax.Precision]]

class easydel.__init__.LlamaForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Llama model for sequence classification tasks.

This class extends the base Llama model by adding a linear classification head to perform sequence classification tasks such as sentiment analysis or text classification.

config#

Configuration for the model.

Type

LlamaConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

class easydel.__init__.LlamaModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Llama model implementation.

This implements the Llama language model architecture, utilizing transformer blocks with RMSNorm, rotary position embeddings, and a specific attention mechanism.

config#

Configuration for the model.

Type

LlamaConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

class easydel.__init__.LlavaConfig(vision_config=None, text_config=None, image_token_index=32000, projector_hidden_act='gelu', vision_feature_select_strategy='default', vision_feature_layer=-2, image_seq_length=576, multimodal_projector_bias=True, **kwargs)[source]#

Bases: EasyDeLBaseConfig

This is the configuration class to store the configuration of a [LlavaForConditionalGeneration]. It is used to instantiate a Llava model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration similar to that of Llava-9B.

e.g. [llava-hf/llava-9b](https://huggingface.co/llava-hf/llava-9b)

Configuration objects inherit from [PretrainedConfig] and can be used to control the model outputs. Read the documentation from [PretrainedConfig] for more information.

Parameters
  • vision_config (Union[AutoConfig, dict], optional, defaults to CLIPVisionConfig) – The config object or dictionary of the vision backbone.

  • text_config (Union[AutoConfig, dict], optional, defaults to LlamaConfig) – The config object or dictionary of the text backbone.

  • image_token_index (int, optional, defaults to 32000) – The image token index to encode the image prompt.

  • projector_hidden_act (str, optional, defaults to “gelu”) – The activation function used by the multimodal projector.

  • vision_feature_select_strategy (str, optional, defaults to “default”) – The feature selection strategy used to select the vision feature from the vision backbone. Can be one of “default” or “full”.

  • vision_feature_layer (Union[int, List[int]], optional, defaults to -2) – The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features.

  • image_seq_length (int, optional, defaults to 576) – Sequence length of one image embedding.

  • multimodal_projector_bias (bool, optional, defaults to True) – Whether to use bias in the multimodal projector.
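
The interplay of image_token_index and image_seq_length can be illustrated with plain token lists. This is a sketch only: image_token_positions and count_images are hypothetical helpers, and it assumes the prompt has already been expanded so that each image occupies exactly image_seq_length placeholder tokens, which may not match how EasyDeL prepares its inputs.

```python
from typing import List


def image_token_positions(input_ids: List[int], image_token_index: int = 32000) -> List[int]:
    """Hypothetical helper: positions whose token id is the image placeholder."""
    return [pos for pos, tok in enumerate(input_ids) if tok == image_token_index]


def count_images(
    input_ids: List[int],
    image_token_index: int = 32000,
    image_seq_length: int = 576,
) -> int:
    """Hypothetical helper: number of images implied by the placeholder count."""
    n_placeholders = len(image_token_positions(input_ids, image_token_index))
    if n_placeholders % image_seq_length != 0:
        raise ValueError("placeholder count is not a multiple of image_seq_length")
    return n_placeholders // image_seq_length


# Three text tokens, one image worth of placeholders, two more text tokens.
prompt_ids = [1, 5, 7] + [32000] * 576 + [9, 2]
positions = image_token_positions(prompt_ids)
num_images = count_images(prompt_ids)
```

In a multimodal forward pass, the projected vision features would be written into exactly these positions of the text embedding sequence.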

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for distributed training by combining the partition rules from both the text and vision configurations.

This method retrieves the partition rules from the text_config and vision_config components and combines them to create a comprehensive set of rules for the entire multimodal model.

Parameters
  • *args – Variable length argument list to be passed to the text and vision configs.

  • **kwargs – Arbitrary keyword arguments to be passed to the text and vision configs.

Returns

A combined tuple of partition rules from both text and vision configurations.

Return type

tuple
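
The combination step described above can be sketched as a simple tuple concatenation. The rule patterns, the tuple stand-ins for PartitionSpec, and the text-before-vision ordering are all assumptions made for illustration; combine_partition_rules is a hypothetical helper.

```python
from typing import Optional, Tuple

# Stand-in for jax.sharding.PartitionSpec so the sketch runs without JAX.
Spec = Tuple[Optional[str], ...]
Rule = Tuple[str, Spec]


def combine_partition_rules(text_rules: Tuple[Rule, ...], vision_rules: Tuple[Rule, ...]) -> Tuple[Rule, ...]:
    """Hypothetical helper: merge the two rule sets into one tuple."""
    # Assumed ordering: text rules first, then vision rules.
    return tuple(text_rules) + tuple(vision_rules)


text_rules = ((r"embed_tokens/embedding", ("tp", "fsdp")),)
vision_rules = ((r"patch_embedding/kernel", (None, "tp")),)
combined = combine_partition_rules(text_rules, vision_rules)
```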

model_type: str = 'llava'#
sub_configs: dict[str, 'PretrainedConfig'] = {'text_config': <class 'easydel.__init__.modules.auto.auto_configuration.AutoEasyDeLConfig'>, 'vision_config': <class 'easydel.__init__.modules.auto.auto_configuration.AutoEasyDeLConfig'>}#
class easydel.__init__.LlavaForConditionalGeneration(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

get_image_features(pixel_values: Union[Array, ndarray, bool, number]) Union[Array, ndarray, bool, number][source]#
loss_type = 'ForCausalLM'#
prepare_inputs_for_generation(input_ids: Union[Array, ndarray, bool, number], max_length: int, pixel_values: Optional[Union[Array, ndarray, bool, number]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

The prepare_inputs_for_generation function is used to prepare the inputs for a generation task.

Parameters
  • self – Access variables that belong to the class

  • input_ids – Pass in the input tokens

  • max_length – Set the length of the sequence to be generated

  • attention_mask – tp.Optional[chex.Array]: Mask the attention weights

  • token_type_ids – tp.Optional[chex.Array]: Token type IDs

Returns

A dictionary of the past_key_values, attention_mask and position ids
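
A rough picture of the returned dictionary, using plain Python lists instead of JAX arrays. prepare_generation_inputs is a hypothetical stand-in rather than the method itself: it assumes the mask is padded with ones up to max_length, that position ids are the running count of attended tokens minus one (clamped at zero), and it leaves past_key_values as a None placeholder where a real cache would be initialized.

```python
from typing import Any, Dict, List, Optional


def prepare_generation_inputs(
    input_ids: List[List[int]],
    max_length: int,
    attention_mask: Optional[List[List[int]]] = None,
) -> Dict[str, Any]:
    """Hypothetical sketch of the inputs prepared before decoding starts."""
    seq_length = len(input_ids[0])
    if attention_mask is None:
        attention_mask = [[1] * seq_length for _ in input_ids]
    # Extend the mask so it covers every position that will be generated.
    extended_mask = [row + [1] * (max_length - seq_length) for row in attention_mask]
    # Position ids: cumulative sum of the mask minus one, clamped at zero.
    position_ids = []
    for row in attention_mask:
        running, ids = 0, []
        for m in row:
            running += m
            ids.append(max(running - 1, 0))
        position_ids.append(ids)
    return {
        "past_key_values": None,  # placeholder; a real cache would be created here
        "attention_mask": extended_mask,
        "position_ids": position_ids,
    }


prepared = prepare_generation_inputs([[11, 12, 13]], max_length=5, attention_mask=[[0, 1, 1]])
```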

update_inputs_for_generation(model_outputs, model_kwargs)[source]#
class easydel.__init__.LossConfig(ignore_index: int = -100, label_smoothing: float = 0.0, z_loss: float = 0.0, loss_normalizing_factor: Optional[Union[float, int, str, SpecialLossNormalizingFactor]] = 'NUM_REAL_TARGET_TOKENS', num_labels: Optional[str] = None, problem_type: Optional[str] = None, divide_weight_sum: bool = False, shift_tokens: bool = True, break_on_nan: bool = True, reduction: Optional[Literal['none', 'mean', 'sum']] = None, num_classification_labels: Optional[int] = None, classification_problem_type: Optional[Literal['regression', 'single_label_classification', 'multi_label_classification']] = None)[source]#

Bases: object

Configuration class for customizing loss computation behavior.

ignore_index#

Specifies a target value that is ignored and does not contribute to the loss. Defaults to -100.

Type

int

label_smoothing#

Amount of label smoothing to apply. 0.0 means no smoothing. Defaults to 0.0.

Type

float

z_loss#

Coefficient for the z-loss regularization term, which encourages logits for non-target classes to be small. Defaults to 0.0.

Type

float

loss_normalizing_factor#

How to normalize the loss. Can be a constant float/int, a string representation of a SpecialLossNormalizingFactor enum, or the enum itself. Defaults to “NUM_REAL_TARGET_TOKENS”.

Type

FACTOR_TYPE

num_labels#

The number of labels for classification tasks. Used in ForSequenceClassificationLoss. Defaults to None.

Type

tp.Optional[int]

problem_type#

Specifies the problem type for sequence classification (e.g., “single_label_classification”, “multi_label_classification”). Defaults to None.

Type

tp.Optional[str]

divide_weight_sum#

If True, divides the loss by the sum of weights, in addition to the loss_normalizing_factor. Defaults to False.

Type

bool

shift_tokens#

If True (typically for Causal LM), shifts the logits and labels so that the model predicts the next token. Defaults to True.

Type

bool

break_on_nan#

If True, raises an EasyDeLBreakRequest if a NaN is encountered during loss computation. Defaults to True.

Type

bool

reduction#

Specifies the reduction to apply to the loss. If None, the default reduction of the specific loss function is used. Defaults to None.

Type

tp.Optional[tp.Literal[“none”, “mean”, “sum”]]

num_classification_labels#

Number of labels specifically for sequence classification. Alias for num_labels. Defaults to None.

Type

tp.Optional[int]

classification_problem_type#

Problem type specifically for sequence classification. Alias for problem_type. Defaults to None.

Type

tp.Optional[tp.Literal[“regression”, “single_label_classification”, “multi_label_classification”]]
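
How ignore_index, shift_tokens, z_loss, and the "NUM_REAL_TARGET_TOKENS" normalization interact can be sketched in pure Python for a single sequence. This is not EasyDeL's loss function: causal_lm_loss is a hypothetical helper, label smoothing is left out, and the z-loss term uses the common squared log-partition form, which is an assumption about the exact formula.

```python
import math
from typing import List


def causal_lm_loss(
    logits: List[List[float]],
    labels: List[int],
    ignore_index: int = -100,
    shift_tokens: bool = True,
    z_loss: float = 0.0,
) -> float:
    """Hypothetical sketch: mean cross-entropy over non-ignored target tokens."""
    if shift_tokens:
        # Position t predicts token t + 1.
        logits, labels = logits[:-1], labels[1:]
    total, num_real = 0.0, 0
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue  # ignored targets contribute nothing
        peak = max(row)
        log_z = peak + math.log(sum(math.exp(v - peak) for v in row))
        total += log_z - row[label]   # cross-entropy for this token
        total += z_loss * log_z ** 2  # assumed z-loss form
        num_real += 1
    # Normalize by the number of real (non-ignored) target tokens.
    return total / max(num_real, 1)


uniform_logits = [[0.0, 0.0, 0.0, 0.0] for _ in range(3)]
example_labels = [0, 2, -100]
plain_loss = causal_lm_loss(uniform_logits, example_labels)
regularized_loss = causal_lm_loss(uniform_logits, example_labels, z_loss=0.1)
```

With uniform logits over four classes, each real token costs log(4), so the ignored final label leaves the mean unchanged.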

break_on_nan: bool = True#
classification_problem_type: Optional[Literal['regression', 'single_label_classification', 'multi_label_classification']] = None#
divide_weight_sum: bool = False#
classmethod from_dict(data)#

Create an instance from a dictionary (deserialization).

classmethod from_json(json_str)#

Create an instance from a JSON string.

ignore_index: int = -100#
label_smoothing: float = 0.0#
loss_normalizing_factor: Optional[Union[float, int, str, SpecialLossNormalizingFactor]] = 'NUM_REAL_TARGET_TOKENS'#
num_classification_labels: Optional[int] = None#
num_labels: Optional[str] = None#
problem_type: Optional[str] = None#
reduction: Optional[Literal['none', 'mean', 'sum']] = None#
replace(**kwargs)#
shift_tokens: bool = True#
to_dict()#

Convert the instance to a dictionary for JSON serialization.

to_json(**kwargs)#

Convert the instance to a JSON string.

z_loss: float = 0.0#
class easydel.__init__.Mamba2Config(num_heads=128, head_dim=64, vocab_size=32768, hidden_size=4096, state_size=128, num_hidden_layers=64, layer_norm_epsilon=1e-05, pad_token_id=1, bos_token_id=0, eos_token_id=2, expand=2, conv_kernel=4, n_groups=8, use_bias=False, use_conv_bias=True, hidden_act='silu', initializer_range=0.1, residual_in_fp32=True, time_step_rank='auto', time_step_min=0.001, time_step_max=0.1, time_step_floor=0.0001, time_step_limit=(0.0, inf), rescale_prenorm_residual=False, use_cache=True, norm_before_gate=True, rms_norm=True, chunk_size=256, tie_word_embeddings=False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 32768) – Vocabulary size of the Mamba2 model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.

  • state_size (int, optional, defaults to 128) – State size of the Mamba2 model.

  • num_hidden_layers (int, optional, defaults to 64) – Number of hidden layers in the Mamba2 encoder.

  • num_heads (int, optional, defaults to 128) – Number of attention heads for the grouped selective scan.

  • head_dim (int, optional, defaults to 64) – Dimension of each attention head.

  • n_groups (int, optional, defaults to 8) – Number of groups for the grouped selective scan.

  • layer_norm_epsilon (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.

  • pad_token_id (int, optional, defaults to 1) – The index of the padding token in the vocabulary.

  • bos_token_id (int, optional, defaults to 0) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.

  • expand (int, optional, defaults to 2) – Expansion factor for the intermediate size.

  • conv_kernel (int, optional, defaults to 4) – Kernel size of the convolution layer.

  • use_bias (bool, optional, defaults to False) – Whether to use bias in the linear layers.

  • use_conv_bias (bool, optional, defaults to True) – Whether to use bias in the convolution layer.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • initializer_range (float, optional, defaults to 0.1) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • residual_in_fp32 (bool, optional, defaults to True) – Whether to compute the residual connection in float32.

  • time_step_rank (str or int, optional, defaults to “auto”) – The rank of the time step embedding. If set to “auto”, the rank is calculated as math.ceil(self.hidden_size / 16).

  • time_step_min (float, optional, defaults to 0.001) – The minimum value for the time step embedding.

  • time_step_max (float, optional, defaults to 0.1) – The maximum value for the time step embedding.

  • time_step_floor (float, optional, defaults to 1e-4) – The floor value for the time step embedding.

  • time_step_limit (tuple, optional, defaults to (0.0, float(“inf”))) – The minimum and maximum limits for the time step.

  • rescale_prenorm_residual (bool, optional, defaults to False) – Whether to rescale the pre-norm residual.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • norm_before_gate (bool, optional, defaults to True) – Whether to apply normalization before the gate activation.

  • rms_norm (bool, optional, defaults to True) – Whether to use root mean square normalization.

  • chunk_size (int, optional, defaults to 256) – Size of chunks for processing long sequences.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the word embedding weights with the output projection weights.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.
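
The documented defaults are mutually consistent, which the following sketch checks. mamba2_derived_sizes is a hypothetical helper; the constraint num_heads * head_dim == expand * hidden_size is an assumption taken from reference Mamba2 implementations and is not stated on this page.

```python
import math
from typing import Dict, Union


def mamba2_derived_sizes(
    hidden_size: int = 4096,
    expand: int = 2,
    num_heads: int = 128,
    head_dim: int = 64,
    time_step_rank: Union[str, int] = "auto",
) -> Dict[str, int]:
    """Hypothetical helper: sizes derived from the configuration values."""
    intermediate_size = expand * hidden_size
    # Assumed constraint: the heads must tile the expanded dimension exactly.
    if num_heads * head_dim != intermediate_size:
        raise ValueError("num_heads * head_dim must equal expand * hidden_size")
    # "auto" rule as documented for time_step_rank.
    rank = math.ceil(hidden_size / 16) if time_step_rank == "auto" else int(time_step_rank)
    return {"intermediate_size": intermediate_size, "time_step_rank": rank}


mamba2_sizes = mamba2_derived_sizes()
```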

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for distributing the Mamba2 model parameters across multiple devices.

These rules define how parameters should be partitioned when using techniques like Fully Sharded Data Parallelism (FSDP), Sharded Parallelism (SP), and Tensor Parallelism (TP). Each rule consists of a regex pattern matching parameter names and a corresponding PartitionSpec.

Returns

A tuple of tuples where each inner tuple contains:
  • A regex pattern matching parameter names

  • A PartitionSpec object specifying how to partition matching parameters

Return type

tuple
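
The (regex pattern, PartitionSpec) structure described above is typically consumed by matching each parameter path against the rules in order. The sketch below shows that lookup with a tuple standing in for PartitionSpec; the rule patterns, axis names, and first-match-wins policy are illustrative assumptions, and match_partition_rule is a hypothetical helper.

```python
import re
from typing import Optional, Sequence, Tuple

# Stand-in for jax.sharding.PartitionSpec so the sketch runs without JAX.
Spec = Tuple[Optional[str], ...]


def match_partition_rule(param_name: str, rules: Sequence[Tuple[str, Spec]]) -> Spec:
    """Hypothetical helper: return the spec of the first rule whose regex matches."""
    for pattern, spec in rules:
        if re.search(pattern, param_name) is not None:
            return spec
    raise ValueError(f"no partition rule matches {param_name!r}")


# Illustrative rules only; the real patterns come from get_partition_rules().
example_rules = (
    (r"embeddings/embedding", ("tp", "fsdp")),
    (r"in_proj/kernel", ("fsdp", "tp")),
    (r".*", (None,)),  # catch-all: replicate everything else
)

in_proj_spec = match_partition_rule("backbone/layers/0/mixer/in_proj/kernel", example_rules)
norm_spec = match_partition_rule("backbone/layers/0/norm/kernel", example_rules)
```

Ending the rule list with a catch-all pattern keeps unmatched parameters replicated instead of raising.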

model_type: str = 'mamba2'#
class easydel.__init__.Mamba2ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

init_cache(batch_size: int, max_length: int)[source]#
prepare_inputs_for_generation(input_ids, inputs_embeds=None, cache_params: Optional[Mamba2Cache] = None, cache_position: Optional[Union[Array, ndarray, bool, number]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None, **kwargs)[source]#

The prepare_inputs_for_generation function is used to prepare the inputs for a generation task.

Parameters
  • self – Access variables that belong to the class

  • input_ids – Pass in the input tokens

  • max_length – Set the length of the sequence to be generated

  • attention_mask – tp.Optional[chex.Array]: Mask the attention weights

  • token_type_ids – tp.Optional[chex.Array]: Token type IDs

Returns

A dictionary of the past_key_values, attention_mask and position ids

update_inputs_for_generation(model_outputs, model_kwargs)[source]#
class easydel.__init__.Mamba2Model(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.MambaConfig(vocab_size=50280, hidden_size=768, state_size=16, num_hidden_layers=32, layer_norm_epsilon=1e-05, pad_token_id=0, bos_token_id=0, eos_token_id=0, expand=2, conv_kernel=4, use_bias=False, use_conv_bias=True, hidden_act='silu', initializer_range=0.1, residual_in_fp32=True, time_step_rank='auto', time_step_scale=1.0, time_step_min=0.001, time_step_max=0.1, time_step_init_scheme='random', time_step_floor=0.0001, rescale_prenorm_residual=False, use_cache=True, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_mambapy: bool = False, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 50280) – Vocabulary size of the Mamba model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 768) – Dimensionality of the encoder layers and the pooler layer.

  • state_size (int, optional, defaults to 16) – State size of the Mamba model.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.

  • layer_norm_epsilon (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.

  • pad_token_id (int, optional, defaults to 0) – The index of the padding token in the vocabulary.

  • bos_token_id (int, optional, defaults to 0) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 0) – The id of the end-of-sequence token.

  • expand (int, optional, defaults to 2) – Expansion factor for the intermediate size.

  • conv_kernel (int, optional, defaults to 4) – Kernel size of the convolution layer.

  • use_bias (bool, optional, defaults to False) – Whether to use bias in the linear layers.

  • use_conv_bias (bool, optional, defaults to True) – Whether to use bias in the convolution layer.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • initializer_range (float, optional, defaults to 0.1) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • residual_in_fp32 (bool, optional, defaults to True) – Whether to compute the residual connection in float32.

  • time_step_rank (str or int, optional, defaults to “auto”) – The rank of the time step embedding. If set to “auto”, the rank is calculated as math.ceil(self.hidden_size / 16).

  • time_step_scale (float, optional, defaults to 1.0) – The scale factor for the time step embedding.

  • time_step_min (float, optional, defaults to 0.001) – The minimum value for the time step embedding.

  • time_step_max (float, optional, defaults to 0.1) – The maximum value for the time step embedding.

  • time_step_init_scheme (str, optional, defaults to “random”) – The initialization scheme for the time step embedding. Possible values are “random” and “uniform”.

  • time_step_floor (float, optional, defaults to 1e-4) – The floor value for the time step embedding.

  • rescale_prenorm_residual (bool, optional, defaults to False) – Whether to rescale the pre-norm residual.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.
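
Two of the time-step settings above lend themselves to a short sketch: the "auto" rule for time_step_rank, and the way time_step_min, time_step_max, and time_step_floor commonly bound the initial step sizes. Both helpers are hypothetical, and the log-uniform sampling scheme is an assumption drawn from reference Mamba implementations, not something this page specifies.

```python
import math
import random
from typing import List, Union


def resolve_time_step_rank(hidden_size: int = 768, time_step_rank: Union[str, int] = "auto") -> int:
    """Hypothetical helper applying the documented "auto" rule."""
    return math.ceil(hidden_size / 16) if time_step_rank == "auto" else int(time_step_rank)


def sample_initial_time_steps(
    n: int,
    time_step_min: float = 0.001,
    time_step_max: float = 0.1,
    time_step_floor: float = 1e-4,
    seed: int = 0,
) -> List[float]:
    """Hypothetical helper: log-uniform samples in [min, max), clamped at the floor."""
    rng = random.Random(seed)
    log_min, log_max = math.log(time_step_min), math.log(time_step_max)
    return [
        max(math.exp(log_min + rng.random() * (log_max - log_min)), time_step_floor)
        for _ in range(n)
    ]


mamba_rank = resolve_time_step_rank()
initial_steps = sample_initial_time_steps(8)
```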

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for distributing the Mamba model parameters across multiple devices.

These rules define how parameters should be partitioned when using techniques like Fully Sharded Data Parallelism (FSDP), Sharded Parallelism (SP), and Tensor Parallelism (TP). Each rule consists of a regex pattern matching parameter names and a corresponding PartitionSpec.

Returns

A tuple of tuples where each inner tuple contains:
  • A regex pattern matching parameter names

  • A PartitionSpec object specifying how to partition matching parameters

Return type

tuple

model_type: str = 'mamba'#
class easydel.__init__.MambaForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

init_cache(batch_size: int, max_length: int)[source]#
prepare_inputs_for_generation(input_ids, max_length, **kwargs)[source]#

The prepare_inputs_for_generation function is used to prepare the inputs for a generation task.

Parameters
  • self – Access variables that belong to the class

  • input_ids – Pass in the input tokens

  • max_length – Set the length of the sequence to be generated

  • attention_mask – tp.Optional[chex.Array]: Mask the attention weights

  • token_type_ids – tp.Optional[chex.Array]: Token type IDs

Returns

A dictionary of the past_key_values, attention_mask and position ids

update_inputs_for_generation(outputs: MambaOutput, model_kwargs: Dict[str, Any], **kwargs) Dict[str, Any][source]#
class easydel.__init__.MambaModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

init_cache(batch_size: int, max_length: int)[source]#
class easydel.__init__.MistralConfig(vocab_size: int = 32000, hidden_size: int = 4096, intermediate_size: int = 14336, head_dim: int = 128, num_hidden_layers: int = 32, num_attention_heads: int = 32, num_key_value_heads: int = 8, hidden_act='silu', max_position_embeddings=131072, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, pad_token_id=None, bos_token_id: int = 1, eos_token_id: int = 2, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling: Dict[str, Union[str, float]] = None, sliding_window=4096, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, number_rep_kv: int = 1, attention_dropout: float = 0.0, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, attention_bias: bool = False, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 32000) – Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 14336) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • head_dim (int, defaults to 128) – Dimensionality of the head for attention.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 8) – Number of key and value heads for each attention layer in the Transformer encoder.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 131072, i.e. 4096 * 32) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional) – The index of the padding token in the vocabulary.

  • bos_token_id (int, optional, defaults to 1) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.

  • sliding_window (int, optional, defaults to 4096) – The sliding window size.

  • gradient_checkpointing (str, optional, defaults to “nothing_saveable”) – The gradient checkpointing configuration.

  • number_rep_kv (int, optional, defaults to 1) – Number of repetitions for the key and value vectors.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • use_scan_mlp (bool, optional, defaults to False) – Whether to use the scan implementation for the MLP.

  • scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size to use when scanning the MLP.

  • bits (int, optional) – The number of bits to quantize the model to.

  • attention_bias (bool, optional, defaults to False) – Whether to use bias in the attention layer.
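
The effect of sliding_window on attention can be pictured as a banded causal mask. The sketch below builds one with nested lists; sliding_window_causal_mask is a hypothetical helper, and the exact boundary convention (whether the window counts the current token) is an assumption that differs between implementations.

```python
from typing import List


def sliding_window_causal_mask(seq_length: int, sliding_window: int) -> List[List[bool]]:
    """Hypothetical helper: mask[q][k] is True when query q may attend to key k."""
    return [
        # Causal (k <= q) and within the last `sliding_window` positions.
        [(k <= q) and (q - k < sliding_window) for k in range(seq_length)]
        for q in range(seq_length)
    ]


window_mask = sliding_window_causal_mask(seq_length=5, sliding_window=3)
```

With the documented default of 4096, sequences shorter than the window behave exactly like ordinary causal attention.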

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, number_rep_kv: int = 1, bits: Optional[int] = None, attention_dropout: float = 0.0, rope_scaling: Dict[str, Union[str, float]] = None, attention_bias: bool = False, **kwargs)[source]#

The attach_custom_arguments function adds the following arguments to the model:

Parameters
  • self – Bind the attributes and methods of a class to an instance of that class

  • gradient_checkpointing – str: Determine whether to use gradient checkpointing

  • use_scan_mlp – bool: Determine whether to use the scan_mlp function or not

  • scan_mlp_chunk_size – int: Chunk the input to the mlp

  • number_rep_kv – int: Control the number of times that the key and value vectors are repeated

  • bits – tp.Optional[int]: Specify the number of bits to use for quantization

  • attention_dropout – float: Set the dropout rate for the attention layer

  • attention_bias – bool: Whether to use a bias in the attention layer

  • rope_scaling – tp.Dict[str, tp.Union[str, float]]: Configuration for rope scaling

Return type

A tuple of the following

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()[source]#
property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
model_type: str = 'mistral'#
static rng_keys()[source]#
class easydel.__init__.MistralForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Mistral model with a language modeling head for causal language modeling tasks.

This model is a transformer-based language model with sliding window attention applied to perform autoregressive language generation.

config#

Configuration for the model.

Type

MistralConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

class easydel.__init__.MistralForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Mistral model for sequence classification tasks.

This class extends the base Mistral model by adding a linear classification head to perform sequence classification tasks such as sentiment analysis or text classification.

config#

Configuration for the model.

Type

MistralConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

class easydel.__init__.MistralModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Mistral model implementation.

This implements the Mistral language model architecture, utilizing transformer blocks with RMSNorm, sliding window attention, and rotary position embeddings.
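To make the RMSNorm component named above concrete, here is a minimal pure-Python sketch of root-mean-square normalization over one hidden vector. The function name, the `eps` default, and the use of plain lists are assumptions for illustration; the library's own layer operates on JAX arrays and may differ in detail.

```python
import math


def rms_norm(x, weight, eps=1e-5):
    """Scale x by the reciprocal of its root mean square, then by a learned weight."""
    mean_square = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_square + eps)
    return [w * v * inv_rms for v, w in zip(x, weight)]


hidden = [3.0, 4.0]
# eps=0.0 here only so the result has a mean square of exactly one.
normalized = rms_norm(hidden, weight=[1.0, 1.0], eps=0.0)
```

Unlike standard layer normalization, no mean is subtracted; only the scale of the vector changes.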

config#

Configuration for the model.

Type

MistralConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

class easydel.__init__.MixtralConfig(vocab_size=32000, hidden_size=4096, intermediate_size=14336, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=8, hidden_act='silu', max_position_embeddings=131072, initializer_range=0.02, rms_norm_eps=1e-05, use_cache=True, pad_token_id=None, bos_token_id=1, eos_token_id=2, tie_word_embeddings=False, rope_theta=1000000.0, sliding_window=4096, attention_dropout=0.0, num_experts_per_tok=2, num_local_experts=8, output_router_logits=False, router_aux_loss_coef=0.001, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, number_rep_kv: int = 1, bits: Optional[int] = None, rope_scaling: Dict[str, Union[str, float]] = None, attention_bias: bool = False, initialization_of_moe: bool = False, router_jitter_noise=0.0, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 32000) – Vocabulary size of the Mixtral model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 14336) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 8) – Number of key and value heads for each attention layer in the Transformer encoder.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 4096 * 32) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional) – The index of the padding token in the vocabulary.

  • bos_token_id (int, optional, defaults to 1) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 1e6) – The theta value to use for rotary position embeddings.

  • sliding_window (int, optional, defaults to 4096) – The sliding window size.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • num_experts_per_tok (int, optional, defaults to 2) – The number of experts per token.

  • num_local_experts (int, optional, defaults to 8) – The number of local experts.

  • output_router_logits (bool, optional, defaults to False) – Whether to output router logits.

  • router_aux_loss_coef (float, optional, defaults to 0.001) – The router auxiliary loss coefficient.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • use_scan_mlp (bool, optional, defaults to False) – Whether to use the scan implementation for the MLP.

  • scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size to use when scanning the MLP.

  • number_rep_kv (int, optional, defaults to 1) – Number of repetitions for the key and value vectors.

  • bits (int, optional) – The number of bits to quantize the model to.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.

  • attention_bias (bool, optional, defaults to False) – Whether to use bias in the attention layer.

  • initialization_of_moe (bool, optional, defaults to False) – Whether to initialize the MoE layers.

  • router_jitter_noise (float, optional, defaults to 0.0) – The jitter noise for the router.
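The interaction of num_local_experts and num_experts_per_tok can be sketched as follows: a router scores every local expert for a token, and only the top-scoring experts process it. The softmax, top-k, renormalize order used here is the scheme commonly described for Mixtral-style models; whether EasyDeL follows it exactly is an assumption, and `route_top_k` is a hypothetical helper.

```python
import math


def route_top_k(router_logits, num_experts_per_tok):
    """Pick the top-k experts for one token and renormalize their weights."""
    # Numerically stable softmax over all local experts.
    peak = max(router_logits)
    exps = [math.exp(v - peak) for v in router_logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the k most probable experts, highest first.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    top = order[:num_experts_per_tok]
    kept = sum(probs[i] for i in top)
    return [(i, probs[i] / kept) for i in top]


# Eight logits, one per local expert (num_local_experts = 8).
logits = [0.1, 2.0, -1.0, 0.5, 1.5, 0.0, -0.5, 0.3]
selected = route_top_k(logits, num_experts_per_tok=2)
```

Each entry of `selected` is an `(expert_index, weight)` pair; the weights of the chosen experts sum to one.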

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, number_rep_kv: int = 1, bits: Optional[int] = None, attention_dropout: float = 0.0, rope_scaling: Dict[str, Union[str, float]] = None, attention_bias: bool = False, initialization_of_moe: bool = False, **kwargs)[source]#

The attach_custom_arguments function adds the following arguments to the model:

Parameters
  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing strategy. Defaults to EasyDeLGradientCheckPointers.NONE.

  • use_scan_mlp (bool, optional) – Whether to use scan for MLP layers. Defaults to False.

  • scan_mlp_chunk_size (int, optional) – Chunk size for scan MLP. Defaults to 1024.

  • number_rep_kv (int, optional) – Number of repetitions for key/value heads. Defaults to 1.

  • bits (tp.Optional[int], optional) – Quantization bits. Defaults to None.

  • attention_dropout (float, optional) – Dropout probability for attention. Defaults to 0.0.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – RoPE scaling configuration. Defaults to None.

  • attention_bias (bool, optional) – Whether to use bias in attention layers. Defaults to False.

  • initialization_of_moe (bool, optional) – Whether MoE layers are being initialized. Defaults to False.

  • **kwargs – Additional keyword arguments (ignored).

Return type

A tuple of the following

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model. This method defines how the model’s parameters are partitioned across devices for distributed training and inference.

Parameters
  • *args – Additional positional arguments (unused).

  • **kwargs – Additional keyword arguments (unused).

Returns

A tuple of partition rules, where each rule is a tuple containing a regex pattern for parameter names and the corresponding PartitionSpec.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]
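The rule format described above (a regex pattern paired with a partition spec) can be illustrated with a small lookup. Everything in this sketch is hypothetical: the patterns, the axis names, the parameter paths, and the first-match-wins policy are assumptions, and plain tuples stand in for PartitionSpec so the example runs without JAX.

```python
import re

# Hypothetical rules: (regex over the parameter path, stand-in spec).
rules = (
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"self_attn/(q|k|v)_proj/kernel", ("fsdp", "tp")),
    (r".*", (None,)),  # fallback: replicate everything else
)


def match_rule(param_path, rules):
    """Return the spec of the first rule whose pattern is found in param_path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path) is not None:
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


spec_q = match_rule("model/layers/0/self_attn/q_proj/kernel", rules)
spec_norm = match_rule("model/norm/kernel", rules)
```

Ordering matters under this policy: the catch-all pattern must come last, otherwise it would shadow the more specific rules.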

static get_weight_decay_exclusions()[source]#

Returns a tuple of parameter names for which weight decay should be excluded.

Returns

An empty tuple, indicating no specific weight decay exclusions for this model.

Return type

tuple

property granted_freq_max_position_embedding: int#

Returns the maximum position embedding size specifically for frequency-based position embeddings.

If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.

Returns

The granted maximum position embedding size for frequency encoding.

Return type

int

property granted_mask_max_position_embedding: int#

Returns the maximum position embedding size specifically for mask-based position embeddings.

If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.

Returns

The granted maximum position embedding size for mask encoding.

Return type

int
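The fallback behaviour of the two granted_* properties can be sketched with a small stand-in class. `PositionLimits` is hypothetical, and treating "set" as "not None" is an assumption about the real configuration class.

```python
class PositionLimits:
    """Hypothetical stand-in for a config holding position-embedding limits."""

    def __init__(self, max_position_embeddings,
                 freq_max_position_embeddings=None,
                 mask_max_position_embeddings=None):
        self.max_position_embeddings = max_position_embeddings
        self.freq_max_position_embeddings = freq_max_position_embeddings
        self.mask_max_position_embeddings = mask_max_position_embeddings

    @property
    def granted_freq_max_position_embedding(self):
        # Prefer the frequency-specific limit, else the general one.
        if self.freq_max_position_embeddings is not None:
            return self.freq_max_position_embeddings
        return self.max_position_embeddings

    @property
    def granted_mask_max_position_embedding(self):
        # Prefer the mask-specific limit, else the general one.
        if self.mask_max_position_embeddings is not None:
            return self.mask_max_position_embeddings
        return self.max_position_embeddings


limits = PositionLimits(131072, mask_max_position_embeddings=8192)
freq_limit = limits.granted_freq_max_position_embedding
mask_limit = limits.granted_mask_max_position_embedding
```

Here only the mask limit is overridden, so the frequency limit falls back to the general maximum.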

model_type: str = 'mixtral'#
static rng_keys()[source]#

Returns the names of the random number generator keys used by the model.

Returns

A tuple containing “params”, “dropout”, and “jitter” as the RNG keys.

Return type

tuple

class easydel.__init__.MixtralForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Mixtral model with a Causal Language Modeling head.

This model consists of the base Mixtral transformer (MixtralModel) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction. It also handles the calculation of the auxiliary loss from the MoE layers.

config#

Configuration object for the model.

Type

MixtralConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

model#

The core Mixtral transformer model.

Type

MixtralModel

lm_head#

The linear layer for projecting hidden states to vocabulary logits.

Type

ParallelLinear

num_experts#

Total number of experts.

Type

int

num_experts_per_tok#

Number of experts to route per token.

Type

int

class easydel.__init__.MixtralForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Mixtral model with a Sequence Classification head.

This model consists of the base Mixtral transformer (MixtralModel) followed by a linear layer (score) that projects the transformer’s output hidden states (typically the hidden state of the first token) to the number of classes for classification. It also handles the calculation of the auxiliary loss from the MoE layers.

config#

Configuration object for the model.

Type

MixtralConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

model#

The core Mixtral transformer model.

Type

MixtralModel

score#

The linear layer for classification.

Type

ParallelLinear

num_experts#

Total number of experts.

Type

int

num_experts_per_tok#

Number of experts to route per token.

Type

int

class easydel.__init__.MixtralModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

The base Mixtral model transformer.

This class represents the core transformer architecture of the Mixtral model, consisting of an embedding layer, multiple MixtralDecoderLayer layers (with sparse MoE), and a final layer normalization.

config#

Configuration object for the model.

Type

MixtralConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

embed_tokens#

Embedding layer for input tokens.

Type

nn.Embed

layers#

List of decoder layers.

Type

tp.List[MixtralDecoderLayer]

norm#

Final layer normalization.

Type

RMSNorm

gradient_checkpointing#

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers

class easydel.__init__.MptAttentionConfig(attn_type='multihead_attention', attn_pdrop=0, attn_impl='torch', clip_qkv=None, softmax_scale=None, prefix_lm=False, qk_ln=False, attn_uses_sequence_id=False, alibi=True, alibi_bias_max=8, **kwargs)[source]#

Bases: EasyDeLBaseConfig

This is the configuration class to store the attention related configuration of a [MptModel].

Parameters
  • attn_type (str, optional, defaults to “multihead_attention”) – The type of attention to use. Can be either “multihead_attention” or “multiquery_attention”.

  • attn_pdrop (float, optional, defaults to 0.0) – The dropout probability applied to the attention output.

  • attn_impl (str, optional, defaults to “torch”) – The implementation of the attention mechanism. Can be either “torch” or “flash”.

  • clip_qkv (float, optional) – The clip value applied to the query, key, and value tensors.

  • softmax_scale (float, optional) – The scale factor applied to the softmax function in the attention layer.

  • prefix_lm (bool, optional, defaults to False) – Whether to use a prefix LM.

  • qk_ln (bool, optional, defaults to False) – Whether to apply layer normalization to the query and key tensors.

  • attn_uses_sequence_id (bool, optional, defaults to False) – Whether the attention layer uses sequence IDs.

  • alibi (bool, optional, defaults to True) – Whether to use the ALiBi (Attention with Linear Biases) method.

  • alibi_bias_max (int, optional, defaults to 8) – The maximum value for the ALiBi bias.
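To show how alibi and alibi_bias_max relate, the sketch below computes per-head ALiBi slopes as a geometric sequence and the resulting linear bias for one query position. It follows the scheme from the ALiBi paper for power-of-two head counts; the helper names are hypothetical and the handling of other head counts in the library is not covered here.

```python
def alibi_slopes(n_heads, alibi_bias_max=8):
    """Per-head slopes 2 ** (-alibi_bias_max * h / n_heads) for h = 1..n_heads."""
    if n_heads <= 0 or (n_heads & (n_heads - 1)) != 0:
        raise ValueError("this sketch only handles power-of-two head counts")
    return [2.0 ** (-alibi_bias_max * h / n_heads) for h in range(1, n_heads + 1)]


def alibi_bias_row(slope, query_pos):
    """Bias for keys 0..query_pos: zero at the query itself, more negative with distance."""
    return [-slope * (query_pos - key_pos) for key_pos in range(query_pos + 1)]


slopes = alibi_slopes(n_heads=8, alibi_bias_max=8)
row = alibi_bias_row(slopes[0], query_pos=3)
```

The bias is added to the attention scores before the softmax, so heads with larger slopes focus more strongly on nearby tokens.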

classmethod from_pretrained(pretrained_model_name_or_path, **kwargs) EasyDeLBaseConfig[source]#

Loads attention configuration from a pretrained model configuration file.

Parameters
  • cls (type) – The class itself.

  • pretrained_model_name_or_path (str) – Path or identifier of the pretrained model.

  • **kwargs – Additional keyword arguments passed to get_config_dict and from_dict.

Returns

An instance of MptAttentionConfig loaded from the pretrained model.

Return type

EasyDeLBaseConfig

class easydel.__init__.MptConfig(d_model: int = 2048, n_heads: int = 16, n_layers: int = 24, expansion_ratio: int = 4, max_seq_len: int = 2048, vocab_size: int = 50368, resid_prob_drop: float = 0.0, layer_norm_epsilon: float = 1e-05, emb_prob_drop: float = 0.0, learned_pos_emb: bool = True, attn_config: Optional[MptAttentionConfig] = None, init_device: str = 'cpu', logit_scale: Optional[Union[float, str]] = None, no_bias: bool = True, verbose: int = 0, embedding_fraction: float = 1.0, norm_type: str = 'low_precision_layernorm', use_cache: bool = False, initializer_range=0.02, alibi: bool = True, use_bias: bool = False, act_fn: str = 'gelu', qk_ln: bool = False, use_lm_head: bool = False, use_norm_bias: bool = False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • d_model (int, optional, defaults to 2048) – Dimensionality of the encoder layers and the pooler layer.

  • n_heads (int, optional, defaults to 16) – Number of attention heads for each attention layer in the Transformer encoder.

  • n_layers (int, optional, defaults to 24) – Number of hidden layers in the Transformer encoder.

  • expansion_ratio (int, optional, defaults to 4) – Expansion ratio of the feed-forward layer.

  • max_seq_len (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • vocab_size (int, optional, defaults to 50368) – Vocabulary size of the MPT model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • resid_prob_drop (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

  • layer_norm_epsilon (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.

  • emb_prob_drop (float, optional, defaults to 0.0) – The dropout ratio for the embeddings.

  • learned_pos_emb (bool, optional, defaults to True) – Whether to learn positional embeddings.

  • attn_config ([MptAttentionConfig], optional) – The configuration of the attention layer.

  • init_device (str, optional, defaults to “cpu”) – The device to initialize the model on.

  • logit_scale (float or str, optional) – The logit scale. If set to “inv_sqrt_d_model”, the logit scale is calculated as 1 / math.sqrt(d_model).

  • no_bias (bool, optional, defaults to True) – Whether to omit bias in the linear layers.

  • verbose (int, optional, defaults to 0) – The verbosity level.

  • embedding_fraction (float, optional, defaults to 1.0) – The fraction of the embedding matrix to use.

  • norm_type (str, optional, defaults to “low_precision_layernorm”) – The type of layer normalization to use.

  • use_cache (bool, optional, defaults to False) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • alibi (bool, optional, defaults to True) – Whether to use ALiBi (Attention with Linear Biases) method.

  • use_bias (bool, optional, defaults to False) – Whether to use bias in the linear layers.

  • act_fn (str, optional, defaults to “gelu”) – The activation function to use.

  • qk_ln (bool, optional, defaults to False) – Whether to apply layer normalization to the query and key tensors.

  • use_lm_head (bool, optional, defaults to False) – Whether to use a language modeling head.

  • use_norm_bias (bool, optional, defaults to False) – Whether to use bias in the layer normalization layers.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • bits (int, optional) – The number of bits to quantize the model to.
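Two of the parameters above imply small computations: logit_scale may be the string "inv_sqrt_d_model", and expansion_ratio scales the feed-forward width. The sketch below shows one way to resolve them. `resolve_logit_scale` is a hypothetical helper, and reading expansion_ratio as a multiplier on d_model is an assumption based on the usual MPT convention.

```python
import math


def resolve_logit_scale(logit_scale, d_model):
    """Turn the documented logit_scale setting into a float multiplier (or None)."""
    if logit_scale is None:
        return None
    if logit_scale == "inv_sqrt_d_model":
        return 1 / math.sqrt(d_model)
    if isinstance(logit_scale, str):
        raise ValueError(f"unsupported logit_scale string: {logit_scale!r}")
    return float(logit_scale)


d_model = 2048
expansion_ratio = 4
scale = resolve_logit_scale("inv_sqrt_d_model", d_model)
# Assumed feed-forward hidden width: expansion_ratio times the model width.
ffn_hidden = expansion_ratio * d_model
```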

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

Attaches custom arguments to the configuration object.

This method allows adding or overriding configuration attributes dynamically. It primarily sets attributes related to gradient checkpointing and quantization bits.

Parameters
  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing strategy. Defaults to EasyDeLGradientCheckPointers.NONE.

  • bits (tp.Optional[int], optional) – Quantization bits. Defaults to None.

  • **kwargs – Additional keyword arguments (ignored).

attribute_map: dict[str, str] = {'hidden_size': 'd_model', 'max_position_embeddings': 'max_seq_len', 'num_attention_heads': 'n_heads', 'num_hidden_layers': 'n_layers', 'tie_word_embeddings': 'use_lm_head'}#
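The effect of attribute_map is that generic names such as hidden_size resolve to the MPT-specific names such as d_model. The sketch below shows one way such aliasing can be implemented in plain Python. `AliasedConfig` is hypothetical; the mechanism actually used by the base configuration class is not shown in this section and may differ.

```python
class AliasedConfig:
    """Hypothetical config that redirects generic names to model-specific ones."""

    attribute_map = {
        "hidden_size": "d_model",
        "max_position_embeddings": "max_seq_len",
        "num_attention_heads": "n_heads",
        "num_hidden_layers": "n_layers",
        "tie_word_embeddings": "use_lm_head",
    }

    def __init__(self, d_model=2048, n_heads=16, n_layers=24):
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers

    def __getattr__(self, name):
        # Called only when normal lookup fails, so real attributes are untouched.
        target = type(self).attribute_map.get(name)
        if target is None:
            raise AttributeError(name)
        return getattr(self, target)

    def __setattr__(self, name, value):
        # Writes through an alias land on the underlying attribute.
        super().__setattr__(type(self).attribute_map.get(name, name), value)


cfg = AliasedConfig()
width = cfg.hidden_size          # reads d_model
cfg.num_hidden_layers = 32       # writes n_layers
```

Code written against the generic names can then work with this configuration without knowing the MPT-specific spelling.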
get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model. This method defines how the model’s parameters are partitioned across devices for distributed training and inference.

Parameters
  • *args – Additional positional arguments (unused).

  • **kwargs – Additional keyword arguments (unused).

Returns

A tuple of partition rules, where each rule is a tuple containing a regex pattern for parameter names and the corresponding PartitionSpec.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

property granted_freq_max_position_embedding: int#

Returns the maximum position embedding size specifically for frequency-based position embeddings.

If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_seq_len.

Returns

The granted maximum position embedding size for frequency encoding.

Return type

int

property granted_mask_max_position_embedding: int#

Returns the maximum position embedding size specifically for mask-based position embeddings.

If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_seq_len.

Returns

The granted maximum position embedding size for mask encoding.

Return type

int

model_type: str = 'mpt'#
class easydel.__init__.MptForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

MPT model with a language modeling head.

This model extends the base MptModel by adding a linear layer (lm_head) on top to predict the next token in a sequence, making it suitable for causal language modeling tasks.

config#

Configuration object for the model.

Type

MptConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

transformer#

The core MPT transformer model.

Type

MptModel

lm_head#

The language modeling head. If use_lm_head in the config is True (tying embeddings), this will be None.

Type

ParallelLinear, optional

class easydel.__init__.MptModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

MPT model implementation.

This class implements the main MPT transformer model architecture, consisting of an embedding layer (token and optional positional), multiple MptBlock layers, and a final layer normalization.

config#

Configuration object for the model.

Type

MptConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

wte#

Token embedding layer.

Type

nn.Embed

emb_drop#

Dropout layer applied after embeddings.

Type

nn.Dropout

blocks#

List of transformer blocks.

Type

tp.List[MptBlock]

norm_f#

Final layer normalization.

Type

nn.LayerNorm

alibi#

Precomputed ALiBi tensor if using ALiBi.

Type

chex.Array, optional

property alibi#
class easydel.__init__.OPTConfig(vocab_size: int = 50272, hidden_size: int = 768, num_hidden_layers: int = 12, ffn_dim: int = 3072, max_position_embeddings: int = 2048, do_layer_norm_before: bool = True, _remove_final_layer_norm: bool = False, word_embed_proj_dim: int = None, dropout: float = 0.1, attention_dropout: float = 0.0, num_attention_heads: int = 12, activation_function: str = 'relu', layerdrop: float = 0.0, init_std: float = 0.02, use_cache: bool = True, pad_token_id: int = 1, bos_token_id: int = 2, eos_token_id: int = 2, enable_bias: bool = True, layer_norm_elementwise_affine: bool = True, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 50272) – Vocabulary size of the OPT model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 768) – Dimensionality of the encoder layers and the pooler layer.

  • num_hidden_layers (int, optional, defaults to 12) – Number of hidden layers in the Transformer encoder.

  • ffn_dim (int, optional, defaults to 3072) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048).

  • do_layer_norm_before (bool, optional, defaults to True) – Whether to perform layer normalization before the attention block.

  • _remove_final_layer_norm (bool, optional, defaults to False) – Whether to remove the final layer norm.

  • word_embed_proj_dim (int, optional) – The dimension of the word embedding projection. If not provided, it will default to hidden_size.

  • dropout (float, optional, defaults to 0.1) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • num_attention_heads (int, optional, defaults to 12) – Number of attention heads for each attention layer in the Transformer encoder.

  • activation_function (str or function, optional, defaults to “relu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • layerdrop (float, optional, defaults to 0.0) – The LayerDrop probability for the encoder. See the LayerDrop paper (https://arxiv.org/abs/1909.11556) for more details.

  • init_std (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional, defaults to 1) – The index of the padding token in the vocabulary.

  • bos_token_id (int, optional, defaults to 2) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.

  • enable_bias (bool, optional, defaults to True) – Whether to use bias in the linear layers.

  • layer_norm_elementwise_affine (bool, optional, defaults to True) – Whether to use elementwise affine in the layer normalization layers.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.
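The layerdrop parameter above refers to randomly skipping whole layers during training. A minimal sketch of that idea follows, using toy callables in place of real decoder layers. The `run_layers` helper and its arguments are hypothetical, and whether the library applies LayerDrop in exactly this way is an assumption.

```python
import random


def run_layers(x, layers, layerdrop, training, rng):
    """Apply layers in order, skipping each with probability layerdrop while training."""
    for layer in layers:
        if training and rng.random() < layerdrop:
            continue  # this layer is dropped for the current step
        x = layer(x)
    return x


# Twelve toy "layers" that each add one, so the output counts applied layers.
layers = [lambda v: v + 1] * 12
rng = random.Random(0)

out_eval = run_layers(0, layers, layerdrop=0.5, training=False, rng=rng)
out_keep = run_layers(0, layers, layerdrop=0.0, training=True, rng=rng)
out_drop = run_layers(0, layers, layerdrop=1.0, training=True, rng=rng)
```

Outside training every layer runs regardless of the probability, which is why evaluation results stay deterministic.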

attach_custom_arguments(vocab_size: int = 50272, hidden_size: int = 768, num_hidden_layers: int = 12, ffn_dim: int = 3072, max_position_embeddings: int = 2048, do_layer_norm_before: bool = True, _remove_final_layer_norm: bool = False, word_embed_proj_dim: int = None, dropout: float = 0.1, attention_dropout: float = 0.0, num_attention_heads: int = 12, activation_function: str = 'relu', layerdrop: float = 0.0, init_std: float = 0.02, use_cache: bool = True, pad_token_id: int = 1, bos_token_id: int = 2, eos_token_id: int = 2, enable_bias: bool = True, layer_norm_elementwise_affine: bool = True, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#

Attaches custom arguments to the configuration object.

This method allows dynamically adding or overriding configuration attributes. It iterates through the provided arguments and sets them as attributes of the configuration object if they don’t already exist.

Parameters
  • vocab_size (int, optional) – Vocabulary size. Defaults to 50272.

  • hidden_size (int, optional) – Dimensionality of the encoder layers. Defaults to 768.

  • num_hidden_layers (int, optional) – Number of hidden layers. Defaults to 12.

  • ffn_dim (int, optional) – Dimensionality of the feed-forward layer. Defaults to 3072.

  • max_position_embeddings (int, optional) – Maximum sequence length. Defaults to 2048.

  • do_layer_norm_before (bool, optional) – Whether to apply layer norm before attention. Defaults to True.

  • _remove_final_layer_norm (bool, optional) – Whether to remove the final layer norm. Defaults to False.

  • word_embed_proj_dim (int, optional) – Dimension of the word embedding projection. Defaults to hidden_size.

  • dropout (float, optional) – Dropout probability. Defaults to 0.1.

  • attention_dropout (float, optional) – Attention dropout probability. Defaults to 0.0.

  • num_attention_heads (int, optional) – Number of attention heads. Defaults to 12.

  • activation_function (str, optional) – Activation function name. Defaults to “relu”.

  • layerdrop (float, optional) – LayerDrop probability. Defaults to 0.0.

  • init_std (float, optional) – Initialization standard deviation. Defaults to 0.02.

  • use_cache (bool, optional) – Whether to use key/value cache. Defaults to True.

  • pad_token_id (int, optional) – Padding token ID. Defaults to 1.

  • bos_token_id (int, optional) – Beginning-of-sequence token ID. Defaults to 2.

  • eos_token_id (int, optional) – End-of-sequence token ID. Defaults to 2.

  • enable_bias (bool, optional) – Whether to use bias in linear layers. Defaults to True.

  • layer_norm_elementwise_affine (bool, optional) – Whether layer norm uses elementwise affine parameters. Defaults to True.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing strategy. Defaults to EasyDeLGradientCheckPointers.NONE.

  • **kwargs – Additional keyword arguments to attach.

get_partition_rules(fully_sharded_data_parallel: bool = True)[source]#

Get the partition rules for the model.

Parameters

fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

keys_to_ignore_at_inference = ['past_key_values']#
model_type: str = 'opt'#
class easydel.__init__.OPTForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

OPT Model with a Causal Language Modeling head.

This model consists of the base OPTModel followed by a linear layer (the language modeling head) to predict the next token logits.

config#

Configuration object for the model.

Type

OPTConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

model#

The base OPT model.

Type

OPTModel

lm_head#

The linear layer for projecting hidden states to vocabulary logits.

Type

ParallelLinear

get_decoder()[source]#

Gets the decoder module from the model.

get_input_embeddings()[source]#

Gets the input embeddings from the model.

get_output_embeddings()[source]#

Gets the output embeddings (language modeling head).

prepare_inputs_for_generation(input_ids, max_length, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

The prepare_inputs_for_generation function prepares the inputs for a generation task.

Parameters
  • self – The model instance.

  • input_ids – The input token IDs.

  • max_length – The maximum length of the sequence to be generated.

  • attention_mask – tp.Optional[chex.Array]: Mask applied to the attention weights.

Returns

A dictionary of the past_key_values, attention_mask and position ids
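The shape of the returned dictionary can be illustrated with a pure-Python sketch. It assumes the common pattern of extending the attention mask to max_length and deriving position ids from a running count of non-padding tokens; `prepare_inputs_sketch` is hypothetical, the cache is a `None` placeholder, and the actual method may compute these values differently.

```python
def prepare_inputs_sketch(input_ids, max_length, attention_mask=None):
    """Build a generation-input dict from nested lists of token ids."""
    batch_size = len(input_ids)
    seq_len = len(input_ids[0])
    if attention_mask is None:
        attention_mask = [[1] * seq_len for _ in range(batch_size)]

    # Extend the mask to max_length so slots for future tokens are attendable.
    extended_mask = [row + [1] * (max_length - seq_len) for row in attention_mask]

    # Position of each token = number of real tokens seen so far, minus one.
    # Padded slots are clamped to 0 in this sketch; they are masked out anyway.
    position_ids = []
    for row in attention_mask:
        running, ids = 0, []
        for flag in row:
            running += flag
            ids.append(max(running - 1, 0))
        position_ids.append(ids)

    return {
        "past_key_values": None,  # placeholder; a real cache would be allocated here
        "attention_mask": extended_mask,
        "position_ids": position_ids,
    }


inputs = prepare_inputs_sketch(
    input_ids=[[1, 1, 5, 6]],        # first two entries are left padding
    max_length=6,
    attention_mask=[[0, 0, 1, 1]],
)
```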

set_decoder(decoder)[source]#

Sets the decoder module for the model.

set_input_embeddings(value)[source]#

Sets the input embeddings for the model.

set_output_embeddings(new_embeddings)[source]#

Sets the output embeddings (language modeling head).

update_inputs_for_generation(model_outputs, model_kwargs)[source]#
class easydel.__init__.OPTModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Base OPT Model class.

This class represents the core OPT model architecture, consisting primarily of the OPTDecoder.

config#

Configuration object for the model.

Type

OPTConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

decoder#

The OPT decoder stack.

Type

OPTDecoder

get_input_embeddings()[source]#

Gets the input embeddings from the model.

set_input_embeddings(value)[source]#

Sets the input embeddings for the model.

class easydel.__init__.ORPOConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: tp.Optional[str] = None, clip_grad: tp.Optional[float] = None, custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = None, dataloader_num_workers: tp.Optional[int] = 0, dataloader_pin_memory: tp.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: tp.Optional[int] = None, evaluation_steps: tp.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: tp.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: tp.Optional[dict] = None, learning_rate: float = 1e-06, learning_rate_end: tp.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: tp.Optional[LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: tp.Optional[int] = None, max_sequence_length: tp.Optional[int] = 4096, max_training_steps: tp.Optional[int] = None, model_name: str = 'ORPOTrainer', model_parameters: tp.Optional[dict] = None, metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: tp.Optional[int] = None, save_total_limit: tp.Optional[int] = None, scheduler: AVAILABLE_SCHEDULERS = EasyDeLSchedulers.NONE, sparsify_module: bool = False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: tp.Optional[dict] = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: tp.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: tp.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: tp.Optional[jnp.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: tp.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0, max_length: tp.Optional[int] = 1024, max_prompt_length: tp.Optional[int] = 512, max_completion_length: tp.Optional[int] = None, beta: float = 0.1, disable_dropout: bool = True, label_pad_token_id: int = -100, padding_value: tp.Optional[int] = None, generate_during_eval: bool = False, is_encoder_decoder: tp.Optional[bool] = None, dataset_num_proc: tp.Optional[int] = None)[source]#

Bases: TrainingArguments

Configuration class for ORPO training settings.

This class inherits from TrainingArguments and holds configuration parameters specific to the ORPO model training. The dataclass automatically generates an initializer, and the __post_init__ method further processes some of the parameters after object initialization.

model_name#

The name of the model. Default is “ORPOTrainer”.

Type

str

learning_rate#

The learning rate used during training. Default is 1e-6.

Type

float

max_length#

The maximum allowed sequence length for the input. Default is 1024.

Type

Optional[int]

max_prompt_length#

The maximum allowed length of the prompt portion of the input. Default is 512.

Type

Optional[int]

max_completion_length#

The maximum allowed length of the completion. If not provided, it is set to max_length - max_prompt_length.

Type

Optional[int]

beta#

Weight of the odds-ratio term in the ORPO loss (written λ in the ORPO paper). Default is 0.1.

Type

float

disable_dropout#

Flag to disable dropout during training. Default is True.

Type

bool

label_pad_token_id#

The token id used for padding labels. Default is -100.

Type

int

padding_value#

The value used for padding sequences. Default is None.

Type

Optional[int]

generate_during_eval#

Flag indicating whether to generate sequences during evaluation. Default is False.

Type

bool

is_encoder_decoder#

Flag to indicate if the model is encoder-decoder. Default is None.

Type

Optional[bool]

model_init_kwargs#

Additional keyword arguments for model initialization. Default is None.

Type

Optional[Dict[str, Any]]

dataset_num_proc#

Number of processes to use for dataset processing. Default is None.

Type

Optional[int]

max_sequence_length#

Computed attribute representing the maximum sequence length used for training. It is set in the __post_init__ method.

Type

int
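The relationship between the three length fields can be sketched with a small stand-in dataclass. LengthBudget is a hypothetical name, not part of EasyDeL; it only mirrors the rule stated above that a missing max_completion_length is derived as max_length - max_prompt_length in __post_init__. How ORPOConfig computes max_sequence_length is not reproduced here.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class LengthBudget:
    """Hypothetical stand-in for the length fields of ORPOConfig."""

    max_length: int = 1024
    max_prompt_length: int = 512
    max_completion_length: Optional[int] = None

    def __post_init__(self):
        # Derive the completion budget when the caller did not provide one.
        if self.max_completion_length is None:
            self.max_completion_length = self.max_length - self.max_prompt_length


default_budget = LengthBudget()
custom_budget = LengthBudget(max_length=2048, max_prompt_length=256)
explicit_budget = LengthBudget(max_completion_length=128)
```

An explicitly passed max_completion_length is left untouched; only the None case is filled in.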

beta: float = 0.1#
dataset_num_proc: Optional[int] = None#
disable_dropout: bool = True#
classmethod from_dict(data)#

Create an instance from a dictionary (deserialization).

classmethod from_json(json_str)#

Create an instance from a JSON string.

generate_during_eval: bool = False#
is_encoder_decoder: Optional[bool] = None#
label_pad_token_id: int = -100#
learning_rate: float = 1e-06#
max_completion_length: Optional[int] = None#
max_length: Optional[int] = 1024#
max_prompt_length: Optional[int] = 512#
model_name: str = 'ORPOTrainer'#
padding_value: Optional[int] = None#
replace(**kwargs)#
to_dict()#

Convert the instance to a dictionary for JSON serialization.

to_json(**kwargs)#

Convert the instance to a JSON string.
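The four serialization helpers above form a round trip. A minimal sketch of that pattern, using a hypothetical TinyConfig dataclass with three of the documented fields, might look like the following. Dropping unknown keys in from_dict is an assumption of this sketch, not documented behavior of ORPOConfig.

```python
import json
from dataclasses import asdict, dataclass, fields


@dataclass
class TinyConfig:
    """Hypothetical config carrying a few of the fields listed above."""

    learning_rate: float = 1e-6
    beta: float = 0.1
    max_length: int = 1024

    def to_dict(self):
        # Plain dict that json.dumps can serialize.
        return asdict(self)

    @classmethod
    def from_dict(cls, data):
        # Assumption: keys that are not dataclass fields are ignored.
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in data.items() if k in known})

    def to_json(self, **kwargs):
        # Extra keyword arguments are forwarded to json.dumps.
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str):
        return cls.from_dict(json.loads(json_str))


original = TinyConfig(beta=0.2)
restored = TinyConfig.from_json(original.to_json(indent=2))
```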

class easydel.__init__.ORPOTrainer(arguments: ORPOConfig, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, data_collator: Optional[DPODataCollatorWithPadding] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, processing_class: Optional[Any] = None)[source]#

Bases: Trainer

arguments: ORPOConfig#
build_tokenized_answer(prompt: str, answer: str) Dict[str, ndarray][source]#

Tokenizes a prompt and answer pair, handling special tokens and padding/truncation.

Parameters
  • prompt (str) – The prompt text.

  • answer (str) – The answer text.

Returns

A dictionary containing the tokenized prompt and answer, along with attention masks.

Return type

tp.Dict[str, np.ndarray]

Raises

ValueError – If there’s a mismatch in token lengths.
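As a rough illustration of what tokenizing a prompt and answer pair involves, the sketch below tokenizes the joined text, tokenizes the prompt alone, and treats the remainder as the answer. Everything here is an assumption: toy_tokenize is a hypothetical whitespace tokenizer, the dictionary keys are illustrative, and plain lists stand in for the np.ndarray values the real method returns.

```python
def toy_tokenize(text):
    """Hypothetical tokenizer: one integer id per whitespace-separated word."""
    return [sum(ord(ch) for ch in word) for word in text.split()]


def build_tokenized_answer_sketch(prompt, answer):
    full_ids = toy_tokenize(prompt + " " + answer)
    prompt_ids = toy_tokenize(prompt)

    # The prompt tokens must be a prefix of the jointly tokenized text;
    # otherwise the prompt/answer boundary cannot be located.
    if full_ids[: len(prompt_ids)] != prompt_ids:
        raise ValueError("Prompt tokens do not align with prompt + answer tokens.")

    answer_ids = full_ids[len(prompt_ids):]
    return {
        "prompt_input_ids": prompt_ids,
        "prompt_attention_mask": [1] * len(prompt_ids),
        "input_ids": answer_ids,
        "attention_mask": [1] * len(answer_ids),
    }


tokens = build_tokenized_answer_sketch("What is JAX", "A library")
```

With a real subword tokenizer the prefix check matters, because tokens can merge across the prompt/answer boundary.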

configure_functions() TrainerConfigureFunctionOutput[source]#

Configures and JIT-compiles the training and evaluation step functions.

This method sets up the necessary functions for training and evaluation, including:
  • Initialization of the model state.

  • Sharding of the model parameters and optimizer state.

  • JIT-compilation of the training and evaluation step functions.

Returns

An object containing the configured functions and other relevant information.

Return type

TrainerConfigureFunctionOutput

create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') Callable[source]#

Creates a data collection function for batching.

For ORPO training, this method simply returns the pre-configured data_collator.

Parameters
  • max_sequence_length (int) – The maximum sequence length (not used in this implementation).

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode (not used in this implementation). Defaults to “keep_end”.

Returns

The data collator function.

Return type

tp.Callable

tokenize_row(feature: Dict[str, str], state: Optional[object] = None) Dict[str, ndarray][source]#

Tokenizes a single row of data from the ORPO dataset.

This method tokenizes the prompt, chosen response, and rejected response, handles padding and truncation, and prepares the data as model input for ORPO training.

Parameters
  • feature (tp.Dict) – A dictionary containing the “prompt”, “chosen”, and “rejected” texts.

  • state (EasyDeLState, optional) – Not used in this implementation. Defaults to None.

Returns

A dictionary containing the tokenized prompt, chosen response, and rejected response, along with attention masks and labels.

Return type

tp.Dict

Raises

ValueError – If the input data types are incorrect.
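The truncation and label handling described above can be sketched on already-tokenized ids. This is an illustration under assumptions, not the trainer's code: the function names are hypothetical, the prompt is truncated according to truncation_mode, the response is cut to fit max_length, and prompt positions in labels are masked with label_pad_token_id.

```python
def truncate(ids, limit, truncation_mode="keep_end"):
    """Keep the first or last `limit` tokens of `ids`."""
    if limit <= 0:
        raise ValueError("limit must be positive")
    if truncation_mode == "keep_start":
        return ids[:limit]
    if truncation_mode == "keep_end":
        return ids[-limit:]
    raise ValueError(f"Unknown truncation mode: {truncation_mode}")


def build_sequence(prompt_ids, response_ids, max_length, max_prompt_length,
                   truncation_mode="keep_end", label_pad_token_id=-100):
    prompt_ids = truncate(prompt_ids, max_prompt_length, truncation_mode)
    # Whatever room is left after the prompt goes to the response.
    room = max(max_length - len(prompt_ids), 0)
    response_ids = response_ids[:room]
    input_ids = prompt_ids + response_ids
    # Prompt positions are masked so they do not contribute to the loss.
    labels = [label_pad_token_id] * len(prompt_ids) + response_ids
    return {
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),
        "labels": labels,
    }


prompt = [11, 12, 13, 14]
chosen_row = build_sequence(prompt, [21, 22, 23], max_length=5, max_prompt_length=2)
rejected_row = build_sequence(prompt, [31, 32], max_length=5, max_prompt_length=2)
start_row = build_sequence(prompt, [21], 5, 2, truncation_mode="keep_start")
```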

class easydel.__init__.Olmo2Config(vocab_size=50304, hidden_size=4096, intermediate_size=11008, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=None, hidden_act='silu', max_position_embeddings=2048, initializer_range=0.02, use_cache=True, pad_token_id=1, bos_token_id=None, eos_token_id=50279, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, attention_bias=False, attention_dropout=0.0, rms_norm_eps=1e-05, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

This is the configuration class to store the configuration of an [Olmo2Model]. It is used to instantiate an OLMo2 model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration similar to that of [allenai/Olmo2-7B-1124-hf](https://huggingface.co/allenai/Olmo2-7B-1124-hf).

Configuration objects inherit from [PretrainedConfig] and can be used to control the model outputs. Read the documentation from [PretrainedConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 50304) – Vocabulary size of the Olmo2 model. Defines the number of different tokens that can be represented by the input_ids passed when calling [Olmo2Model].

  • hidden_size (int, optional, defaults to 4096) – Dimension of the hidden representations.

  • intermediate_size (int, optional, defaults to 11008) – Dimension of the MLP representations.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer decoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer decoder.

  • num_key_value_heads (int, optional) – Number of key/value heads used to implement Grouped Query Attention (GQA). If num_key_value_heads=num_attention_heads, the model uses Multi Head Attention (MHA); if num_key_value_heads=1, it uses Multi Query Attention (MQA); otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group's key and value head should be constructed by mean-pooling all the original heads within that group. For more details, see [this paper](https://arxiv.org/pdf/2305.13245.pdf). If not specified, defaults to num_attention_heads.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) in the decoder.

  • max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional, defaults to 1) – Padding token id.

  • bos_token_id (int, optional) – Beginning of stream token id.

  • eos_token_id (int, optional, defaults to 50279) – End of stream token id.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The base period of the RoPE embeddings.

  • rope_scaling (tp.Dict, optional) – Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is {“type”: strategy name, “factor”: scaling factor}. When using this flag, don’t update max_position_embeddings to the expected new maximum. See the following thread for more information on how these scaling strategies behave: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an experimental feature, subject to breaking API changes in future versions.

  • attention_bias (bool, optional, defaults to False) – Whether to use a bias in the query, key, value and output projection layers during self-attention.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • rms_norm_eps (float, optional, defaults to 1e-05) – The epsilon used by the rms normalization layers.
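The num_key_value_heads rule in the list above (MHA, MQA or GQA, and mean-pooling when converting a multi-head checkpoint) can be sketched in plain Python. Both function names are hypothetical, and each head is represented as a flat list of floats purely for illustration.

```python
def attention_kind(num_attention_heads, num_key_value_heads=None):
    """Classify the attention variant implied by the two head counts."""
    if num_key_value_heads is None:
        num_key_value_heads = num_attention_heads  # documented default
    if num_attention_heads % num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    if num_key_value_heads == num_attention_heads:
        return "MHA"
    if num_key_value_heads == 1:
        return "MQA"
    return "GQA"


def mean_pool_heads(heads, num_key_value_heads):
    """Average consecutive groups of heads down to num_key_value_heads heads."""
    group_size = len(heads) // num_key_value_heads
    pooled = []
    for group in range(num_key_value_heads):
        members = heads[group * group_size:(group + 1) * group_size]
        # Element-wise mean over the heads that belong to this group.
        pooled.append([sum(column) / group_size for column in zip(*members)])
    return pooled


key_heads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
pooled_heads = mean_pool_heads(key_heads, num_key_value_heads=2)
```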

>>> from transformers import Olmo2Model, Olmo2Config
>>> # Initializing a Olmo2 7B style configuration
>>> configuration = Olmo2Config()
>>> # Initializing a model from the Olmo2 7B style configuration
>>> model = Olmo2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
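The rope_scaling format described in the parameter list ({"type": strategy name, "factor": scaling factor}, with the strategies linear and dynamic and a float factor greater than 1) lends itself to a small validation helper. validate_rope_scaling is a hypothetical function written for this illustration; it is not claimed to match the checks the library performs.

```python
def validate_rope_scaling(rope_scaling):
    """Check a rope_scaling dict against the documented format."""
    if rope_scaling is None:
        return None  # scaling disabled
    if not isinstance(rope_scaling, dict) or set(rope_scaling) != {"type", "factor"}:
        raise ValueError("rope_scaling must be a dict with exactly 'type' and 'factor'")
    kind, factor = rope_scaling["type"], rope_scaling["factor"]
    if kind not in ("linear", "dynamic"):
        raise ValueError(f"Unsupported rope_scaling type: {kind!r}")
    if not isinstance(factor, float) or factor <= 1.0:
        raise ValueError("rope_scaling factor must be a float greater than 1")
    return rope_scaling


valid_scaling = validate_rope_scaling({"type": "linear", "factor": 2.0})
```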
attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None)[source]#

Attaches custom arguments to the configuration object.

This method allows adding or overriding configuration attributes dynamically. It primarily sets attributes related to gradient checkpointing, MLP scanning, and quantization bits.

Parameters
  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing strategy. Defaults to EasyDeLGradientCheckPointers.NONE.

  • use_scan_mlp (bool, optional) – Whether to use scan for MLP layers. Defaults to False.

  • scan_mlp_chunk_size (int, optional) – Chunk size for scan MLP. Defaults to 1024.

  • bits (tp.Optional[int], optional) – Quantization bits. Defaults to None.

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model. This method defines how the model’s parameters are partitioned across devices for distributed training and inference.

Parameters
  • *args – Additional positional arguments (unused).

  • **kwargs – Additional keyword arguments (unused).

Returns

A tuple of partition rules, where each rule is a tuple containing a regex pattern for parameter names and the corresponding PartitionSpec.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]
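To make the shape of the return value concrete, the sketch below resolves a parameter name against a tuple of (regex, spec) rules. The patterns and parameter names are hypothetical, plain tuples stand in for PartitionSpec objects so the example runs without JAX, and first-match-wins is an assumption about how such rules are applied.

```python
import re

# Hypothetical rules: (regex over parameter names, partition spec).
# Plain tuples stand in for jax.sharding.PartitionSpec objects.
PARTITION_RULES = (
    (r"embed_tokens/embedding", ("tp", ("fsdp", "sp"))),
    (r"self_attn/(q_proj|k_proj|v_proj)/kernel", (("fsdp", "sp"), "tp")),
    (r".*", (None,)),  # catch-all: replicate anything not matched above
)


def resolve_spec(param_name, rules=PARTITION_RULES):
    """Return the spec of the first rule whose pattern matches the name."""
    for pattern, spec in rules:
        if re.search(pattern, param_name):
            return spec
    raise ValueError(f"No partition rule matches {param_name!r}")


attention_spec = resolve_spec("model/layers/0/self_attn/q_proj/kernel")
fallback_spec = resolve_spec("model/norm/kernel")
```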

property granted_freq_max_position_embedding: int#

Returns the maximum position embedding size specifically for frequency-based position embeddings.

If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.

Returns

The granted maximum position embedding size for frequency encoding.

Return type

int

property granted_mask_max_position_embedding: int#

Returns the maximum position embedding size specifically for mask-based position embeddings.

If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.

Returns

The granted maximum position embedding size for mask encoding.

Return type

int
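The fallback behavior of the two granted_* properties can be sketched with a stand-in class. PositionConfig is hypothetical, and treating "set" as "not None" is an assumption of this sketch.

```python
class PositionConfig:
    """Hypothetical config holding only the position-embedding fields."""

    def __init__(self, max_position_embeddings,
                 freq_max_position_embeddings=None,
                 mask_max_position_embeddings=None):
        self.max_position_embeddings = max_position_embeddings
        self.freq_max_position_embeddings = freq_max_position_embeddings
        self.mask_max_position_embeddings = mask_max_position_embeddings

    @property
    def granted_freq_max_position_embedding(self):
        # Prefer the frequency-specific override, else the general maximum.
        value = self.freq_max_position_embeddings
        return self.max_position_embeddings if value is None else value

    @property
    def granted_mask_max_position_embedding(self):
        # Prefer the mask-specific override, else the general maximum.
        value = self.mask_max_position_embeddings
        return self.max_position_embeddings if value is None else value


plain = PositionConfig(max_position_embeddings=2048)
overridden = PositionConfig(2048, freq_max_position_embeddings=8192)
```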

keys_to_ignore_at_inference = ['past_key_values']#
model_type: str = 'olmo2'#
class easydel.__init__.Olmo2ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

OLMo-2 model with a Causal Language Modeling head.

This model consists of the base OLMo-2 transformer (Olmo2Model) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction.

config#

Configuration object for the model.

Type

Olmo2Config

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

model#

The core OLMo-2 transformer model.

Type

Olmo2Model

lm_head#

The linear layer for projecting hidden states to vocabulary logits.

Type

ParallelLinear
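The role of lm_head, projecting hidden states to vocabulary logits for next-token prediction, can be shown with a toy matrix product. This is a sketch on nested Python lists with made-up numbers; it assumes a bias-free projection and does not use the actual ParallelLinear layer.

```python
def project_to_vocab(hidden_states, lm_head_kernel):
    """Multiply [seq_len, hidden] states by a [hidden, vocab] kernel."""
    return [
        [sum(h * w for h, w in zip(state, column)) for column in zip(*lm_head_kernel)]
        for state in hidden_states
    ]


def greedy_next_token(logits):
    """Pick the highest-scoring vocabulary id at the last position."""
    last = logits[-1]
    return max(range(len(last)), key=last.__getitem__)


hidden_states = [[1.0, 0.0], [0.5, 2.0]]             # seq_len=2, hidden=2
lm_head_kernel = [[0.1, 0.9, 0.0], [0.3, 0.2, 1.0]]  # hidden=2, vocab=3
logits = project_to_vocab(hidden_states, lm_head_kernel)
next_token = greedy_next_token(logits)
```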

class easydel.__init__.Olmo2ForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

OLMo-2 model with a Sequence Classification head.

This model consists of the base OLMo-2 transformer (Olmo2Model) followed by a linear layer (score) that projects the transformer’s output hidden states (typically the hidden state of the last token) to the number of classes for classification.

config#

Configuration object for the model.

Type

Olmo2Config

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

model#

The core OLMo-2 transformer model.

Type

Olmo2Model

score#

The linear layer for classification.

Type

ParallelLinear
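Using "the hidden state of the last token" for classification requires locating the last non-padding position in each sequence. A minimal sketch, assuming right-padded inputs and using hypothetical helper names and made-up scores:

```python
def last_token_index(input_ids, pad_token_id):
    """Index of the last non-padding token, assuming right padding."""
    for index in range(len(input_ids) - 1, -1, -1):
        if input_ids[index] != pad_token_id:
            return index
    return len(input_ids) - 1  # all padding: fall back to the final position


def pool_class_scores(per_token_scores, input_ids, pad_token_id):
    """Select the class scores computed at the last real token."""
    return per_token_scores[last_token_index(input_ids, pad_token_id)]


input_ids = [5, 7, 9, 1, 1]  # pad_token_id=1, the Olmo2Config default
per_token_scores = [[0.1, 0.9], [0.4, 0.6], [0.8, 0.2], [0.5, 0.5], [0.5, 0.5]]
pooled = pool_class_scores(per_token_scores, input_ids, pad_token_id=1)
predicted_class = max(range(len(pooled)), key=pooled.__getitem__)
```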

class easydel.__init__.Olmo2Model(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

The base OLMo-2 model transformer.

This class represents the core transformer architecture of the OLMo-2 model, consisting of an embedding layer, multiple Olmo2DecoderLayer layers, and a final RMS normalization layer.

config#

Configuration object for the model.

Type

Olmo2Config

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

embed_tokens#

Embedding layer for input tokens.

Type

nn.Embed

layers#

List of decoder layers.

Type

tp.List[Olmo2DecoderLayer]

norm#

Final layer normalization.

Type

RMSNorm

gradient_checkpointing#

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers

class easydel.__init__.OlmoConfig(vocab_size=50304, hidden_size=4096, intermediate_size=11008, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=None, hidden_act='silu', max_position_embeddings=2048, initializer_range=0.02, use_cache=True, pad_token_id=1, bos_token_id=None, eos_token_id=50279, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, attention_bias=False, attention_dropout=0.0, clip_qkv=None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 50304) – Vocabulary size of the Olmo model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4096) – Dimensionality of the hidden representations.

  • intermediate_size (int, optional, defaults to 11008) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer decoder.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer decoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer decoder.

  • num_key_value_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer decoder. Defaults to num_attention_heads if not set.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the decoder. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional, defaults to 1) – The index of the padding token in the vocabulary.

  • bos_token_id (int, optional) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 50279) – The id of the end-of-sequence token.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.

  • attention_bias (bool, optional, defaults to False) – Whether to use attention bias.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • clip_qkv (float, optional) – The clip value applied to the query, key, and value tensors.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • use_scan_mlp (bool, optional, defaults to False) – Whether to use the scan implementation for the MLP.

  • scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size to use when scanning the MLP.

  • bits (int, optional) – The number of bits to quantize the model to.

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None)[source]#

Attaches custom arguments to the configuration object.

This method allows adding or overriding configuration attributes dynamically. It primarily sets attributes related to gradient checkpointing, MLP scanning, and quantization bits.

Parameters
  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing strategy. Defaults to EasyDeLGradientCheckPointers.NONE.

  • use_scan_mlp (bool, optional) – Whether to use scan for MLP layers. Defaults to False.

  • scan_mlp_chunk_size (int, optional) – Chunk size for scan MLP. Defaults to 1024.

  • bits (tp.Optional[int], optional) – Quantization bits. Defaults to None.

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model. This method defines how the model’s parameters are partitioned across devices for distributed training and inference.

Parameters
  • *args – Additional positional arguments (unused).

  • **kwargs – Additional keyword arguments (unused).

Returns

A tuple of partition rules, where each rule is a tuple containing a regex pattern for parameter names and the corresponding PartitionSpec.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

property granted_freq_max_position_embedding: int#

Returns the maximum position embedding size specifically for frequency-based position embeddings.

If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.

Returns

The granted maximum position embedding size for frequency encoding.

Return type

int

property granted_mask_max_position_embedding: int#

Returns the maximum position embedding size specifically for mask-based position embeddings.

If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.

Returns

The granted maximum position embedding size for mask encoding.

Return type

int

model_type: str = 'olmo'#
class easydel.__init__.OlmoForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

OLMo model with a Causal Language Modeling head.

This model consists of the base OLMo transformer (OlmoModel) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction.

config#

Configuration object for the model.

Type

OlmoConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

transformer#

The core OLMo transformer model.

Type

OlmoModel

lm_head#

The linear layer for projecting hidden states to vocabulary logits.

Type

ParallelLinear

class easydel.__init__.OlmoModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

The base OLMo model transformer.

This class represents the core transformer architecture of the OLMo model, consisting of an embedding layer and multiple OlmoDecoderLayer layers. Note that OLMo does not have a final layer normalization.

config#

Configuration object for the model.

Type

OlmoConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

embed_tokens#

Embedding layer for input tokens.

Type

nn.Embed

layers#

List of decoder layers.

Type

tp.List[OlmoDecoderLayer]

gradient_checkpointing#

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers

class easydel.__init__.OpenELMConfig(vocab_size: int = 32000, max_context_length: int = 2048, num_transformer_layers: int = 12, model_dim: int = 2048, head_dim: int = 128, qkv_multipliers: Union[Number, List[Number]] = 1.0, num_query_heads: Optional[int] = None, num_gqa_groups: int = 1, ffn_multipliers: Union[Number, List[Number]] = 4.0, ffn_with_glu: bool = True, ffn_dim_divisor: int = 256, activation_fn_name: str = 'swish', normalization_layer_name: str = 'rms_norm', normalize_qk_projections: bool = False, share_input_output_layers: bool = False, rope_freq_constant: int = 10000, rope_max_length: int = 4096, initializer_range: float = 0.02, use_cache: bool = True, bos_token_id: int = 1, eos_token_id: int = 2, rope_scaling: Dict[str, Union[str, float]] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 32000) – Vocabulary size of the OpenELM model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • max_context_length (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • num_transformer_layers (int, optional, defaults to 12) – Number of hidden layers in the Transformer decoder.

  • model_dim (int, optional, defaults to 2048) – Dimensionality of the hidden representations.

  • head_dim (int, optional, defaults to 128) – Dimensionality of the attention heads.

  • qkv_multipliers (float or list of float, optional, defaults to 1.0) – The multiplier for the query, key, and value projections.

  • num_query_heads (int, optional) – Number of query heads. If not provided, it will be calculated based on model_dim and head_dim.

  • num_gqa_groups (int, optional, defaults to 1) – Number of GQA (Grouped Query Attention) groups.

  • ffn_multipliers (float or list of float, optional, defaults to 4.0) – The multiplier for the feed-forward network.

  • ffn_with_glu (bool, optional, defaults to True) – Whether to use a gated linear unit (GLU) in the feed-forward network.

  • ffn_dim_divisor (int, optional, defaults to 256) – The divisor for the feed-forward network dimension.

  • activation_fn_name (str, optional, defaults to “swish”) – The activation function to use.

  • normalization_layer_name (str, optional, defaults to “rms_norm”) – The normalization layer to use.

  • normalize_qk_projections (bool, optional, defaults to False) – Whether to normalize the query and key projections.

  • share_input_output_layers (bool, optional, defaults to False) – Whether to share the input and output layers.

  • rope_freq_constant (int, optional, defaults to 10000) – The frequency constant for Rotary Position Embeddings (RoPE).

  • rope_max_length (int, optional, defaults to 4096) – The maximum length for RoPE.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • bos_token_id (int, optional, defaults to 1) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • use_scan_mlp (bool, optional, defaults to False) – Whether to use the scan implementation for the MLP.

  • scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size to use when scanning the MLP.

  • bits (int, optional) – The number of bits to quantize the model to.
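Several of the parameters above are combined into derived sizes: num_query_heads is computed from model_dim and head_dim when it is not given, and ffn_multipliers together with ffn_dim_divisor determine feed-forward widths. The sketch below shows one plausible way to do this. The helper names and the exact rounding rule (round to the nearest multiple of the divisor, bumping up if that loses more than 10%) are assumptions, not a statement of what EasyDeL computes.

```python
def compute_heads(model_dim, head_dim):
    """Number of query heads implied by the model and head dimensions."""
    if model_dim % head_dim != 0:
        raise ValueError("model_dim must be divisible by head_dim")
    return model_dim // head_dim


def make_divisible(value, divisor):
    """Round value to the nearest multiple of divisor (at least one divisor)."""
    rounded = max(divisor, int(value + divisor / 2) // divisor * divisor)
    # Avoid shrinking the value by more than 10% through rounding.
    if rounded < 0.9 * value:
        rounded += divisor
    return rounded


model_dim, head_dim, ffn_dim_divisor = 2048, 128, 256
num_query_heads = compute_heads(model_dim, head_dim)

# ffn_multipliers may be a single number or one value per layer.
ffn_multipliers = [0.5, 1.0, 4.0]
ffn_dims = [make_divisible(m * model_dim, ffn_dim_divisor) for m in ffn_multipliers]
```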

attribute_map: dict[str, str] = {'tie_word_embedding': 'share_input_output_layers'}#
get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model. This method defines how the model’s parameters are partitioned across devices for distributed training and inference.

Parameters
  • *args – Additional positional arguments (unused).

  • **kwargs – Additional keyword arguments (unused).

Returns

A tuple of partition rules, where each rule is a tuple containing a regex pattern for parameter names and the corresponding PartitionSpec.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()[source]#

Returns a tuple of parameter names for which weight decay should be excluded.

Returns

A tuple containing ‘bias’, ‘normalization’, and ‘emb’ as exclusions.

Return type

tuple
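One way such an exclusion tuple is typically consumed is to decide, per parameter name, whether weight decay applies. The sketch below assumes simple substring matching and uses hypothetical parameter names; how EasyDeL actually applies the exclusions is not documented here.

```python
def weight_decay_exclusions_sketch():
    """Stand-in returning the three documented exclusion keywords."""
    return ("bias", "normalization", "emb")


def should_decay(param_name, exclusions):
    """Assumption: a parameter is excluded if its name contains any keyword."""
    return not any(keyword in param_name for keyword in exclusions)


exclusions = weight_decay_exclusions_sketch()
param_names = [
    "transformer/token_embeddings/embedding",
    "transformer/layers/0/attn/qkv_proj/kernel",
    "transformer/layers/0/attn/qkv_proj/bias",
    "transformer/layers/0/ffn_norm/normalization/scale",
]
decayed = [name for name in param_names if should_decay(name, exclusions)]
```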

property granted_freq_max_position_embedding: int#

Returns the maximum position embedding size specifically for frequency-based position embeddings.

If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_context_length.

Returns

The granted maximum position embedding size for frequency encoding.

Return type

int

property granted_mask_max_position_embedding: int#

Returns the maximum position embedding size specifically for mask-based position embeddings.

If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_context_length.

Returns

The granted maximum position embedding size for mask encoding.

Return type

int

model_type: str = 'openelm'#
static rng_keys()[source]#

Returns the names of the random number generator keys used by the model.

Returns

A tuple containing “params” and “dropout” as the RNG keys.

Return type

tuple

class easydel.__init__.OpenELMForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

OpenELM model with a Causal Language Modeling head.

This model consists of the base OpenELM transformer (OpenELMModel) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction. Optionally, the input token embeddings can be tied to the output projection layer.

config#

Configuration object for the model.

Type

OpenELMConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

transformer#

The core OpenELM transformer model.

Type

OpenELMModel

lm_head#

The linear layer for projecting hidden states to vocabulary logits. This is None if config.share_input_output_layers is True.

Type

ParallelLinear, optional

class easydel.__init__.OpenELMModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

The base OpenELM model transformer.

This class represents the core transformer architecture of the OpenELM model, consisting of an embedding layer, multiple OpenELMDecoderLayer layers, and a final RMS normalization layer.

config#

Configuration object for the model.

Type

OpenELMConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

token_embeddings#

Embedding layer for input tokens.

Type

nn.Embed

layers#

List of decoder layers.

Type

tp.List[OpenELMDecoderLayer]

norm#

Final layer normalization.

Type

RMSNorm

gradient_checkpointing#

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers

property frequencies#

Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_frequencies() and caches the result.

Returns

The frequency components, potentially cached.

Return type

jnp.ndarray
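
A compute-once-then-cache property of this kind can be sketched as follows. The class `FrequencyHolder` is hypothetical, and the formula is the standard RoPE inverse-frequency definition; what `get_basic_frequencies()` actually returns in the library is not stated here, so treat the computation as an assumption.

```python
from functools import cached_property


class FrequencyHolder:
    """Hypothetical stand-in for a module that lazily computes RoPE frequencies."""

    def __init__(self, head_dim: int, rope_theta: float = 10000.0):
        self.head_dim = head_dim
        self.rope_theta = rope_theta
        self.compute_calls = 0  # counts how often the computation really runs

    @cached_property
    def frequencies(self) -> list:
        # Standard RoPE inverse frequencies: theta ** (-2i / head_dim)
        # for i in [0, head_dim / 2).
        self.compute_calls += 1
        return [
            self.rope_theta ** (-(2 * i) / self.head_dim)
            for i in range(self.head_dim // 2)
        ]


holder = FrequencyHolder(head_dim=8)
first = holder.frequencies
second = holder.frequencies  # served from the cache, no recomputation
```

The second access returns the same object as the first, which is the behaviour the phrase "potentially cached" suggests.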

class easydel.__init__.PartitionAxis(data_parallel_axis: str = 'dp', fully_sharded_data_parallel_axis: str = 'fsdp', tensor_parallel_axis: str = 'tp', sequence_parallel_axis: str = 'sp', expert_parallel_axis: str = 'ep', batch_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis, sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis, query_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis, head_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis, key_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis, hidden_state_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis, mlp_intermediate_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis, vocab_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis, expert_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis, expert_gate_axis: Optional[Union[Tuple[str, ...], str, Any]] = None, attention_dim_axis: Optional[Union[Tuple[str, ...], str, Any]] = None, bias_head_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = None, bias_key_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = None, generation_batch_axis: Optional[Union[Tuple[str, ...], str, Any]] = None, generation_query_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = None, generation_head_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis, generation_key_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis, generation_attention_dim_axis: Optional[Union[Tuple[str, ...], str, Any]] = None)[source]#

Bases: object

Configuration for partitioning model axes across a device mesh.

Defines the mesh dimension names for standard parallelism strategies and maps logical model axes to these dimensions. Allows overriding defaults.

Mesh Dimensions:

  • data_parallel_axis: Name for data parallel mesh dim. Default: “dp”.

  • fully_sharded_data_parallel_axis: Name for FSDP mesh dim. Default: “fsdp”.

  • tensor_parallel_axis: Name for tensor parallel mesh dim. Default: “tp”.

  • sequence_parallel_axis: Name for sequence parallel mesh dim. Default: “sp”.

  • expert_parallel_axis: Name for expert parallel mesh dim (MoE). Default: “ep”.

Logical Model Axes:

Maps logical tensor axes (like batch, sequence, hidden) to one or more mesh dimension names defined above, or None if not partitioned. Defaults are derived from the standard mesh dimension names but can be overridden during instantiation. For example, head_axis defaults to the value of tensor_parallel_axis (‘tp’).
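
The idea of a logical axis that defaults to a mesh dimension name, unless overridden, can be sketched with a two-field dataclass. `AxisSketch` and `resolved_head_axis` are hypothetical names; the sketch assumes that `Ellipsis` marks "derive from the mesh dimension names" and that an explicit `None` means "not partitioned". How and when the library resolves these values is not shown here.

```python
from dataclasses import dataclass, replace
from typing import Any


@dataclass(frozen=True)
class AxisSketch:
    """Hypothetical two-field miniature of a partition-axis configuration."""

    tensor_parallel_axis: str = "tp"
    head_axis: Any = ...  # Ellipsis: derive from the mesh dimension names

    def resolved_head_axis(self) -> Any:
        # Fall back to the tensor-parallel mesh dimension unless overridden.
        if self.head_axis is ...:
            return self.tensor_parallel_axis
        return self.head_axis


default = AxisSketch()
renamed = AxisSketch(tensor_parallel_axis="model")
unsharded = replace(default, head_axis=None)  # explicit None: not partitioned
```

Renaming the mesh dimension carries through to every logical axis that still holds the sentinel, while an explicit override is left alone.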

attention_dim_axis: Optional[Union[Tuple[str, ...], str, Any]] = None#
batch_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis#
bias_head_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = None#
bias_key_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = None#
data_parallel_axis: str = 'dp'#
expert_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis#
expert_gate_axis: Optional[Union[Tuple[str, ...], str, Any]] = None#
expert_parallel_axis: str = 'ep'#
classmethod from_dict(data)#

Create an instance from a dictionary (deserialization).

classmethod from_json(json_str)#

Create an instance from a JSON string.

fully_sharded_data_parallel_axis: str = 'fsdp'#
generation_attention_dim_axis: Optional[Union[Tuple[str, ...], str, Any]] = None#
generation_batch_axis: Optional[Union[Tuple[str, ...], str, Any]] = None#
generation_head_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis#
generation_key_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis#
generation_query_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = None#
head_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis#
hidden_state_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis#
key_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis#
mlp_intermediate_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis#
query_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis#
replace(**kwargs)#
sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis#
sequence_parallel_axis: str = 'sp'#
tensor_parallel_axis: str = 'tp'#
to_dict()#

Convert the instance to a dictionary for JSON serialization.

to_json(**kwargs)#

Convert the instance to a JSON string.

vocab_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis#
class easydel.__init__.Phi3Config(vocab_size=32064, hidden_size=3072, intermediate_size=8192, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=None, resid_pdrop=0.0, embd_pdrop=0.0, attention_dropout=0.0, hidden_act='silu', max_position_embeddings=4096, original_max_position_embeddings=4096, initializer_range=0.02, rms_norm_eps=1e-05, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, bos_token_id=1, eos_token_id=32000, pad_token_id=32000, sliding_window=None, bits: Optional[int] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 32064) – Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 3072) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 8192) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer encoder. Will default to num_attention_heads if not set.

  • resid_pdrop (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

  • embd_pdrop (float, optional, defaults to 0.0) – The dropout ratio for the embeddings.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 4096) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • original_max_position_embeddings (int, optional, defaults to 4096) – The maximum sequence length that the model was originally trained with. It is used together with rope_scaling to determine the original RoPE embedding size when extending the context length.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.

  • bos_token_id (int, optional, defaults to 1) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 32000) – The id of the end-of-sequence token.

  • pad_token_id (int, optional, defaults to 32000) – The index of the padding token in the vocabulary.

  • sliding_window (int, optional) – The sliding window size.

  • bits (int, optional) – The number of bits to quantize the model to.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]
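
One way to consume rules of this shape is sketched below. It assumes that the string in each pair is a regular expression matched against a parameter path and that the first matching rule wins; neither is stated above. The paths, patterns, and the `spec_for` helper are hypothetical, and plain tuples stand in for `PartitionSpec` so the sketch runs without JAX.

```python
import re

# Stand-in rules: (regex over the parameter path, sharding spec as a tuple).
# A real rule would carry a PartitionSpec instead of a plain tuple.
rules = (
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"self_attn/.*_proj/kernel", ("fsdp", "tp")),
    (r".*", ()),  # catch-all: replicate anything not matched above
)


def spec_for(param_path, rules):
    """Return the spec of the first rule whose regex matches the path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


attn_spec = spec_for("model/layers/0/self_attn/q_proj/kernel", rules)
norm_spec = spec_for("model/norm/kernel", rules)
```

Ordering matters in this scheme: specific patterns go first and the catch-all goes last.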

property granted_freq_max_position_embedding: int#

Returns the maximum position embedding size specifically for frequency-based position embeddings.

If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.

Returns

The granted maximum position embedding size for frequency encoding.

Return type

int

property granted_mask_max_position_embedding: int#

Returns the maximum position embedding size specifically for mask-based position embeddings.

If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.

Returns

The granted maximum position embedding size for mask encoding.

Return type

int
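
The fallback described by the two properties above can be sketched as follows. `PositionLimits` is a hypothetical fragment that models only the three fields involved, and "set" is assumed to mean "not None".

```python
from typing import Optional


class PositionLimits:
    """Hypothetical config fragment; only the relevant fields are modelled."""

    def __init__(
        self,
        max_position_embeddings: int,
        freq_max_position_embeddings: Optional[int] = None,
        mask_max_position_embeddings: Optional[int] = None,
    ):
        self.max_position_embeddings = max_position_embeddings
        self.freq_max_position_embeddings = freq_max_position_embeddings
        self.mask_max_position_embeddings = mask_max_position_embeddings

    @property
    def granted_freq_max_position_embedding(self) -> int:
        # Assumption: "set" means "not None".
        if self.freq_max_position_embeddings is not None:
            return self.freq_max_position_embeddings
        return self.max_position_embeddings

    @property
    def granted_mask_max_position_embedding(self) -> int:
        if self.mask_max_position_embeddings is not None:
            return self.mask_max_position_embeddings
        return self.max_position_embeddings


limits = PositionLimits(
    max_position_embeddings=4096,
    freq_max_position_embeddings=8192,
)
```

Here the frequency limit uses its own override, while the mask limit falls back to max_position_embeddings.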

model_type: str = 'phi3'#
class easydel.__init__.Phi3ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Phi-3 model with a Causal Language Modeling head.

This model consists of the base Phi-3 transformer (Phi3Model) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction. Optionally, the input token embeddings can be tied to the output projection layer.

config#

Configuration object for the model.

Type

Phi3Config

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

model#

The core Phi-3 transformer model.

Type

Phi3Model

lm_head#

The linear layer for projecting hidden states to vocabulary logits.

Type

ParallelLinear

class easydel.__init__.Phi3Model(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

The base Phi-3 model transformer.

This class represents the core transformer architecture of the Phi-3 model, consisting of an embedding layer, multiple Phi3DecoderLayer layers, and a final RMS normalization layer.

config#

Configuration object for the model.

Type

Phi3Config

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

embed_tokens#

Embedding layer for input tokens.

Type

nn.Embed

embed_dropout#

Dropout layer applied after embeddings.

Type

nn.Dropout

layers#

List of decoder layers.

Type

tp.List[Phi3DecoderLayer]

norm#

Final layer normalization.

Type

RMSNorm

gradient_checkpointing#

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers

property frequencies#

Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_frequencies() and caches the result.

Returns

The frequency components, potentially cached.

Return type

jnp.ndarray

class easydel.__init__.PhiConfig(vocab_size=51200, hidden_size=2048, intermediate_size=8192, num_hidden_layers=24, num_attention_heads=32, num_key_value_heads=None, resid_pdrop=0.0, embd_pdrop=0.0, attention_dropout=0.0, hidden_act='gelu_new', max_position_embeddings=2048, initializer_range=0.02, layer_norm_eps=1e-05, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, partial_rotary_factor=0.5, qk_layernorm=False, bos_token_id=1, eos_token_id=2, bits: Optional[int] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 51200) – Vocabulary size of the Phi model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 2048) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 8192) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 24) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer encoder. Will default to num_attention_heads if not set.

  • resid_pdrop (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

  • embd_pdrop (float, optional, defaults to 0.0) – The dropout ratio for the embeddings.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • hidden_act (str or function, optional, defaults to “gelu_new”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • layer_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.

  • partial_rotary_factor (float, optional, defaults to 0.5) – The factor for partial rotary embeddings.

  • qk_layernorm (bool, optional, defaults to False) – Whether to apply layer normalization to the query and key tensors.

  • bos_token_id (int, optional, defaults to 1) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.

  • bits (int, optional) – The number of bits to quantize the model to.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.
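
As a worked example of partial_rotary_factor with the defaults listed above, the arithmetic below assumes the factor is the fraction of each attention head's dimension that receives rotary embeddings, and that the rotated slice is the leading one. Both points are assumptions for illustration, and `split_for_rotary` is a hypothetical helper.

```python
hidden_size = 2048
num_attention_heads = 32
partial_rotary_factor = 0.5

head_dim = hidden_size // num_attention_heads        # features per head
rotary_dim = int(partial_rotary_factor * head_dim)   # features that get RoPE


def split_for_rotary(head_vector, rotary_dim):
    """Split one head's features into a rotated part and a pass-through part."""
    return head_vector[:rotary_dim], head_vector[rotary_dim:]


rotated, passed_through = split_for_rotary(list(range(head_dim)), rotary_dim)
```

With these defaults each head has 64 features, of which 32 would be rotated and 32 left unchanged.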

attribute_map: dict[str, str] = {'hidden_size': 'n_embd', 'max_position_embeddings': 'n_positions', 'num_attention_heads': 'num_attention_heads', 'num_hidden_layers': 'num_hidden_layers'}#
get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

property granted_freq_max_position_embedding: int#

Returns the maximum position embedding size specifically for frequency-based position embeddings.

If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.

Returns

The granted maximum position embedding size for frequency encoding.

Return type

int

property granted_mask_max_position_embedding: int#

Returns the maximum position embedding size specifically for mask-based position embeddings.

If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.

Returns

The granted maximum position embedding size for mask encoding.

Return type

int

model_type: str = 'phi'#
class easydel.__init__.PhiForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Phi model with a Causal Language Modeling head.

This model consists of the base Phi transformer (PhiModel) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction. Optionally, the input token embeddings can be tied to the output projection layer.

config#

Configuration object for the model.

Type

PhiConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

transformer#

The core Phi transformer model.

Type

PhiModel

lm_head#

The linear layer for projecting hidden states to vocabulary logits.

Type

ParallelLinear

class easydel.__init__.PhiModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

The base Phi model transformer.

This class represents the core transformer architecture of the Phi model, consisting of an embedding layer, multiple PhiDecoderLayer layers, and a final layer normalization.

config#

Configuration object for the model.

Type

PhiConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

embed_tokens#

Embedding layer for input tokens.

Type

nn.Embed

layers#

List of decoder layers.

Type

tp.List[PhiDecoderLayer]

final_layernorm#

Final layer normalization.

Type

nn.LayerNorm

embed_dropout#

Dropout layer applied after embeddings.

Type

nn.Dropout

gradient_checkpointing#

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers

property frequencies#

Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_frequencies() and caches the result.

Returns

The frequency components, potentially cached.

Return type

jnp.ndarray

class easydel.__init__.PhiMoeConfig(vocab_size=32064, hidden_size=4096, intermediate_size=6400, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=8, hidden_act='silu', max_position_embeddings=131072, initializer_range=0.02, rms_norm_eps=1e-05, use_cache=True, pad_token_id=None, bos_token_id=1, eos_token_id=2, tie_word_embeddings=False, rope_theta=1000000.0, rope_scaling=None, sliding_window=None, attention_dropout=0.0, num_experts_per_tok=2, num_local_experts=16, output_router_logits=False, router_aux_loss_coef=0.001, router_jitter_noise=0.01, input_jitter_noise=0.0, attention_bias=False, embd_pdrop: float = 0.0, lm_head_bias=False, bits: Optional[int] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 32064) – Vocabulary size of the PhiMoE model. Defines the number of different tokens that can be represented by the inputs_ids passed when calling [PhiMoeModel].

  • hidden_size (int, optional, defaults to 4096) – Dimension of the hidden representations.

  • intermediate_size (int, optional, defaults to 6400) – Dimension of the MLP representations.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 8) – The number of key/value heads used to implement Grouped Query Attention (GQA). If num_key_value_heads=num_attention_heads, the model uses Multi-Head Attention (MHA); if num_key_value_heads=1, it uses Multi-Query Attention (MQA); otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group’s key and value head should be constructed by mean-pooling all the original heads within that group. For more details, check out [this paper](https://arxiv.org/pdf/2305.13245.pdf). If not specified, defaults to 8.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) in the decoder.

  • max_position_embeddings (int, optional, defaults to 4096*32, i.e. 131072) – The maximum sequence length that this model might ever be used with. Mixtral’s sliding window attention allows sequences of up to 4096*32 tokens.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-05) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional) – The id of the padding token.

  • bos_token_id (int, optional, defaults to 1) – The id of the “beginning-of-sequence” token.

  • eos_token_id (int, optional, defaults to 2) – The id of the “end-of-sequence” token.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether the model’s input and output word embeddings should be tied.

  • rope_theta (float, optional, defaults to 1000000.0) – The base period of the RoPE embeddings.

  • rope_scaling (dict, optional) – The scaling strategy for the RoPE embeddings. If None, no scaling is applied. If a dictionary, it must contain the following keys: type, short_factor, long_factor, short_mscale, long_mscale and original_max_position_embeddings. The type must be longrope; short_mscale and long_mscale must be numbers; short_factor and long_factor must be lists of numbers with the same length as half of the attention head size; and original_max_position_embeddings must be an integer.

  • sliding_window (int, optional, defaults to None) – Sliding window attention window size.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • num_experts_per_tok (int, optional, defaults to 2) – The number of experts to route each token to; it can also be interpreted as the top-k routing parameter.

  • num_local_experts (int, optional, defaults to 16) – Number of experts per Sparse MLP layer.

  • output_router_logits (bool, optional, defaults to False) – Whether or not the router logits should be returned by the model. Enabling this will also allow the model to output the auxiliary loss.

  • router_aux_loss_coef (float, optional, defaults to 0.001) – The aux loss factor for the total loss.

  • router_jitter_noise (float, optional, defaults to 0.01) – Amount of noise to add to the router.

  • bits (int, optional) – The number of bits to quantize the model to.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.
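
The mean-pooling rule mentioned for num_key_value_heads can be sketched with NumPy. The helper name and the weight layout (one leading axis per head) are assumptions chosen for illustration; real checkpoints usually store the heads fused into a single projection matrix that would need reshaping first.

```python
import numpy as np


def mean_pool_kv_heads(kv_heads, num_key_value_heads):
    """Pool per-head weights of shape (num_heads, ...) into fewer KV heads."""
    num_heads = kv_heads.shape[0]
    if num_heads % num_key_value_heads != 0:
        raise ValueError("num_heads must be divisible by num_key_value_heads")
    group_size = num_heads // num_key_value_heads
    # Consecutive heads form one group; average inside each group.
    grouped = kv_heads.reshape(
        num_key_value_heads, group_size, *kv_heads.shape[1:]
    )
    return grouped.mean(axis=1)


# 32 multi-head key projections pooled down to 8 grouped-query heads.
mha_keys = np.arange(32 * 2 * 3, dtype=np.float64).reshape(32, 2, 3)
gqa_keys = mean_pool_kv_heads(mha_keys, 8)
```

With 32 attention heads and 8 key/value heads, every group of 4 consecutive heads collapses into one shared head.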

attach_custom_arguments(bits: Optional[int] = None, embd_pdrop: float = 0.0, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#

Attaches custom arguments to the configuration object.

This method allows dynamically adding or overriding configuration attributes. It primarily sets attributes related to quantization, dropout, and gradient checkpointing. Any additional keyword arguments are also set as attributes if they don’t already exist.

Parameters
  • bits (tp.Optional[int], optional) – Quantization bits. Defaults to None.

  • embd_pdrop (float, optional) – Dropout probability for embeddings. Defaults to 0.0.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing strategy. Defaults to EasyDeLGradientCheckPointers.NONE.

  • **kwargs – Additional keyword arguments to attach.
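
The "set only if missing" behaviour for extra keyword arguments can be sketched as follows. `ConfigSketch` is a hypothetical stand-in, gradient_checkpointing is left out for brevity, and the sketch assumes that the named arguments are always written while extra keywords never overwrite an existing attribute.

```python
class ConfigSketch:
    """Hypothetical stand-in for a configuration object."""

    def __init__(self):
        self.vocab_size = 32064  # an attribute that already exists

    def attach_custom_arguments(self, bits=None, embd_pdrop=0.0, **kwargs):
        # Assumption: the named arguments are always written.
        self.bits = bits
        self.embd_pdrop = embd_pdrop
        # Extra keyword arguments only fill in attributes that are missing.
        for name, value in kwargs.items():
            if not hasattr(self, name):
                setattr(self, name, value)


cfg = ConfigSketch()
cfg.attach_custom_arguments(bits=8, vocab_size=1, my_flag=True)
```

After the call, bits and my_flag are set, while the pre-existing vocab_size keeps its original value.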

get_partition_rules(fully_sharded_data_parallel: bool = True)[source]#

Get the partition rules for the model.

Parameters

fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

property granted_freq_max_position_embedding: int#

Returns the maximum position embedding size specifically for frequency-based position embeddings.

If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.

Returns

The granted maximum position embedding size for frequency encoding.

Return type

int

property granted_mask_max_position_embedding: int#

Returns the maximum position embedding size specifically for mask-based position embeddings.

If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.

Returns

The granted maximum position embedding size for mask encoding.

Return type

int

model_type: str = 'phimoe'#
class easydel.__init__.PhiMoeForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

PhiMoE model with a Causal Language Modeling head.

This model consists of the base PhiMoE transformer (PhiMoeModel) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction. Optionally, the input token embeddings can be tied to the output projection layer.

config#

Configuration object for the model.

Type

PhiMoeConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

model#

The core PhiMoE transformer model.

Type

PhiMoeModel

lm_head#

The linear layer for projecting hidden states to vocabulary logits.

Type

ParallelLinear

class easydel.__init__.PhiMoeModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

The base PhiMoE model transformer.

This class represents the core transformer architecture of the PhiMoE model, consisting of an embedding layer, multiple PhiMoeDecoderLayer layers (which include Sparse Mixture of Experts blocks), and a final RMS normalization layer.
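
To illustrate what a sparse Mixture-of-Experts block does with num_experts_per_tok, the sketch below implements plain top-k routing with a softmax over the selected experts. This is a generic technique shown for orientation only; it is not necessarily the routing rule this implementation uses, and `route_top_k` is a hypothetical helper.

```python
import numpy as np


def route_top_k(router_logits, k):
    """Pick k experts per token and softmax-normalize their weights.

    router_logits: (num_tokens, num_experts)
    Returns (expert indices, weights), each of shape (num_tokens, k).
    """
    # Indices of the k largest logits per token (unordered within the k).
    top_idx = np.argpartition(router_logits, -k, axis=-1)[:, -k:]
    top_logits = np.take_along_axis(router_logits, top_idx, axis=-1)
    # Numerically stable softmax over the selected experts only.
    shifted = top_logits - top_logits.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    weights /= weights.sum(axis=-1, keepdims=True)
    return top_idx, weights


logits = np.array([[0.1, 2.0, -1.0, 1.5],
                   [3.0, 0.0, 0.5, 2.5]])
experts, weights = route_top_k(logits, k=2)
```

Each token's output would then be the weighted sum of the outputs of its selected experts, so only k of the experts run per token.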

config#

Configuration object for the model.

Type

PhiMoeConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

embed_tokens#

Embedding layer for input tokens.

Type

nn.Embed

layers#

List of decoder layers.

Type

tp.List[PhiMoeDecoderLayer]

norm#

Final layer normalization.

Type

RMSNorm

embed_dropout#

Dropout layer applied after embeddings.

Type

nn.Dropout

gradient_checkpointing#

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers

class easydel.__init__.PixtralVisionConfig(hidden_size: int = 1024, intermediate_size: int = 4096, num_hidden_layers: int = 24, num_attention_heads: int = 16, num_channels: int = 3, image_size: int = 1024, patch_size: int = 16, hidden_act: str = 'gelu', attention_dropout: float = 0.0, rope_theta: float = 10000.0, initializer_range: int = 0.02, **kwargs)[source]#

Bases: EasyDeLBaseConfig

This is the configuration class to store the configuration of a [PixtralVisionModel]. It is used to instantiate a Pixtral vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to the vision encoder used by Pixtral-12B.

e.g. [pixtral-hf/pixtral-9b](https://huggingface.co/pixtral-hf/pixtral-9b)

Configuration objects inherit from [PretrainedConfig] and can be used to control the model outputs. Read the documentation from [PretrainedConfig] for more information.

Parameters
  • hidden_size (int, optional, defaults to 1024) – Dimension of the hidden representations.

  • intermediate_size (int, optional, defaults to 4096) – Dimension of the MLP representations.

  • num_hidden_layers (int, optional, defaults to 24) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 16) – Number of attention heads in the Transformer encoder.

  • num_channels (int, optional, defaults to 3) – Number of input channels in the input images.

  • image_size (int, optional, defaults to 1024) – Max dimension of the input images.

  • patch_size (int, optional, defaults to 16) – Size of the image patches.

  • hidden_act (str, optional, defaults to “gelu”) – Activation function used in the hidden layers.

  • attention_dropout (float, optional, defaults to 0.0) – Dropout probability for the attention layers.

  • rope_theta (float, optional, defaults to 10000.0) – The base period of the RoPE embeddings.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

Example:

```python
>>> from transformers import PixtralVisionModel, PixtralVisionConfig

>>> # Initializing a Pixtral-12B style configuration
>>> configuration = PixtralVisionConfig()
>>> # Initializing a model (with randomly initialized weights) from the configuration
>>> model = PixtralVisionModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'pixtral'#
class easydel.__init__.PixtralVisionModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

The Pixtral Vision Model transformer.

This class implements the complete Pixtral vision model, including patch embedding via convolution and the main transformer stack.
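
Patch embedding via convolution can be sketched as a matrix product over flattened patches, which is what a convolution computes when its kernel size and stride both equal the patch size. That stride is an assumption here, `patchify` is a hypothetical helper, and the sizes are shrunk so the demo stays small.

```python
import numpy as np

# Small demo sizes; the documented defaults are image_size=1024, patch_size=16.
image_size, patch_size, num_channels, hidden_size = 64, 16, 3, 8


def patchify(image, patch_size):
    """(H, W, C) image -> (num_patches, patch_size * patch_size * C)."""
    h, w, c = image.shape
    rows, cols = h // patch_size, w // patch_size
    patches = image[: rows * patch_size, : cols * patch_size].reshape(
        rows, patch_size, cols, patch_size, c
    )
    # Bring the two patch-grid axes to the front, then flatten each patch.
    patches = patches.transpose(0, 2, 1, 3, 4)
    return patches.reshape(rows * cols, patch_size * patch_size * c)


rng = np.random.default_rng(0)
image = rng.normal(size=(image_size, image_size, num_channels))
kernel = rng.normal(size=(patch_size * patch_size * num_channels, hidden_size))

patches = patchify(image, patch_size)  # one row per patch
embeddings = patches @ kernel          # one hidden vector per patch

# Patch count for a full-size image with the documented defaults.
max_patches = (1024 // 16) ** 2
```

With the documented defaults, a 1024 by 1024 image yields a 64 by 64 grid of patches, which is the sequence length the transformer stack would see for that image.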

config#

Configuration object for the model.

Type

PixtralVisionConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

patch_conv#

Convolutional layer for patch embedding.

Type

nn.Conv

transformer#

The main transformer stack.

Type

PixtralTransformer

ln_pre#

Layer normalization applied before the transformer blocks.

Type

RMSNorm

property frequencies#

Cached property to compute and retrieve RoPE frequencies.

class easydel.__init__.PyTree[source]#

Bases: _PyTreeNodeBase

Base class for dataclasses that should act like a JAX pytree node.

class easydel.__init__.Qwen2Config(vocab_size=151936, hidden_size=4096, intermediate_size=22016, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=32, hidden_act='silu', max_position_embeddings=32768, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, use_sliding_window=False, sliding_window=4096, max_window_layers=28, attention_dropout=0.0, resid_pdrop: float = 0.0, embd_pdrop: float = 0.0, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, fcm_min_ratio: float = 0.0, fcm_max_ratio: float = 0.0, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, number_rep_kv: int = 1, bits: Optional[int] = None, scan_layers: bool = True, rope_scaling: Optional[Mapping[str, str | float]] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 151936) – Vocabulary size of the Qwen-2 model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 22016) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 32) – Number of key and value heads for each attention layer in the Transformer encoder.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 32768) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • use_sliding_window (bool, optional, defaults to False) – Whether to use a sliding window attention.

  • sliding_window (int, optional, defaults to 4096) – The sliding window size.

  • max_window_layers (int, optional, defaults to 28) – The maximum number of layers to use for the sliding window attention.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • resid_pdrop (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

  • embd_pdrop (float, optional, defaults to 0.0) – The dropout ratio for the embeddings.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • fcm_min_ratio (float, optional, defaults to 0.0) – The minimum mask ratio for forgetful causal masking (FCM).

  • fcm_max_ratio (float, optional, defaults to 0.0) – The maximum mask ratio for forgetful causal masking (FCM).

  • use_scan_mlp (bool, optional, defaults to False) – Whether to use the scan implementation for the MLP.

  • scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size to use when scanning the MLP.

  • number_rep_kv (int, optional, defaults to 1) – Number of repetitions for the key and value vectors.

  • bits (int, optional) – The number of bits to quantize the model to.

  • scan_layers (bool, optional, defaults to True) – Whether to use the scan implementation for the layers.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.
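The relationship between num_attention_heads and num_key_value_heads can be sketched as follows. The helper `kv_head_for_query_head` is hypothetical, and deriving the head dimension as hidden_size // num_attention_heads is the usual convention rather than something this page states.

```python
def kv_head_for_query_head(query_head: int, num_attention_heads: int, num_key_value_heads: int) -> int:
    """Index of the key/value head shared by a given query head."""
    if num_attention_heads % num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    group_size = num_attention_heads // num_key_value_heads
    return query_head // group_size


hidden_size, num_attention_heads = 4096, 32  # Qwen2Config defaults
head_dim = hidden_size // num_attention_heads  # assumed convention

# With the default of 32 key/value heads every query head has its own KV head.
own_kv = kv_head_for_query_head(5, 32, 32)
# With 8 key/value heads, query heads 4..7 all share KV head 1.
shared_kv = kv_head_for_query_head(5, 32, 8)
print(head_dim, own_kv, shared_kv)  # 128 5 1
```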

attach_custom_arguments(resid_pdrop: float = 0.0, embd_pdrop: float = 0.0, attention_dropout: float = 0.0, tie_word_embeddings: bool = False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, fcm_min_ratio: float = 0.0, fcm_max_ratio: float = 0.0, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, number_rep_kv: int = 1, bits: Optional[int] = None, rope_theta: float = 10000.0, hidden_act: str = 'silu', scan_layers: bool = True, rope_scaling: Optional[Mapping[str, str | float]] = None, **kwargs)[source]#

Attaches the following custom arguments to the configuration object:

Parameters
  • resid_pdrop – float: Dropout probability for residual connections.

  • embd_pdrop – float: Dropout probability for the embeddings.

  • attention_dropout – float: Dropout probability for the attention weights.

  • tie_word_embeddings – bool: Whether to tie the input and output word embeddings.

  • gradient_checkpointing – EasyDeLGradientCheckPointers: Gradient checkpointing policy, which trades recomputation for lower memory use.

  • fcm_min_ratio – float: Minimum mask ratio for forgetful causal masking (FCM).

  • fcm_max_ratio – float: Maximum mask ratio for forgetful causal masking (FCM).

  • use_scan_mlp – bool: Determine whether to use the scan_mlp function or not.

  • scan_mlp_chunk_size – int: Set the chunk size for scan_mlp.

  • number_rep_kv – int: Determine how many times the key and value vectors are repeated.

  • bits – tp.Optional[int]: Determine the number of bits used in the quantization.

  • rope_theta – float: Base value for RoPE.

  • hidden_act – str: Activation function name.

  • scan_layers – bool: Determine whether to use scan layers or not.

  • rope_scaling (tp.Optional[tp.Mapping[str, str | float]], optional) – RoPE scaling configuration.

  • **kwargs – Additional keyword arguments to attach.

get_partition_rules(fully_sharded_data_parallel: bool = True)[source]#

Get the partition rules for the model.

Parameters

fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]
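A minimal sketch of how (pattern, spec) rules of this shape might be resolved against parameter names. The patterns, the axis names, and the first-match-wins lookup are all assumptions made for illustration; plain tuples stand in for PartitionSpec so the sketch runs without JAX.

```python
import re

# Hypothetical rules: (regex over the parameter path, stand-in for a PartitionSpec).
RULES = (
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"self_attn/(q_proj|k_proj|v_proj)/kernel", ("fsdp", "tp")),
    (r"self_attn/o_proj/kernel", ("tp", "fsdp")),
    (r".*", (None,)),  # catch-all: leave everything else replicated
)


def match_partition_rule(rules, param_name):
    """Return the spec of the first rule whose pattern matches the name."""
    for pattern, spec in rules:
        if re.search(pattern, param_name):
            return spec
    raise ValueError(f"no partition rule matches {param_name!r}")


spec_q = match_partition_rule(RULES, "model/layers/0/self_attn/q_proj/kernel")
spec_norm = match_partition_rule(RULES, "model/norm/kernel")
print(spec_q, spec_norm)  # ('fsdp', 'tp') (None,)
```

Ordering matters under this assumption: the catch-all pattern has to come last, otherwise it would shadow the more specific rules.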

static get_weight_decay_exclusions()[source]#

Returns a tuple of parameter names for which weight decay should be excluded.
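One way such a tuple could be consumed is to build a per-parameter mask for the optimizer. The sketch below assumes substring matching and uses made-up exclusion names and parameter paths; the actual contents of the returned tuple are not documented here.

```python
def weight_decay_mask(param_names, exclusions):
    """Map each parameter name to True if weight decay should be applied to it."""
    return {
        name: not any(key in name for key in exclusions)
        for name in param_names
    }


# Hypothetical values standing in for the result of get_weight_decay_exclusions().
exclusions = ("bias", "norm")
names = [
    "model/layers/0/self_attn/q_proj/kernel",
    "model/layers/0/self_attn/q_proj/bias",
    "model/norm/kernel",
]
mask = weight_decay_mask(names, exclusions)
print(mask)
```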

property granted_freq_max_position_embedding: int#

Returns the maximum position embedding size specifically for frequency-based position embeddings.

If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.

Returns

The granted maximum position embedding size for frequency encoding.

Return type

int

property granted_mask_max_position_embedding: int#

Returns the maximum position embedding size specifically for mask-based position embeddings.

If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.

Returns

The granted maximum position embedding size for mask encoding.

Return type

int
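The fallback behaviour of the two granted_* properties can be sketched with a small stand-in class. `PositionLimits` is hypothetical, and treating "set" as "not None" is an assumption about the real check.

```python
from typing import Optional


class PositionLimits:
    """Hypothetical stand-in for the relevant configuration fields."""

    def __init__(
        self,
        max_position_embeddings: int,
        freq_max_position_embeddings: Optional[int] = None,
        mask_max_position_embeddings: Optional[int] = None,
    ):
        self.max_position_embeddings = max_position_embeddings
        self.freq_max_position_embeddings = freq_max_position_embeddings
        self.mask_max_position_embeddings = mask_max_position_embeddings

    @property
    def granted_freq_max_position_embedding(self) -> int:
        if self.freq_max_position_embeddings is not None:
            return self.freq_max_position_embeddings
        return self.max_position_embeddings

    @property
    def granted_mask_max_position_embedding(self) -> int:
        if self.mask_max_position_embeddings is not None:
            return self.mask_max_position_embeddings
        return self.max_position_embeddings


limits = PositionLimits(32768, freq_max_position_embeddings=4096)
print(limits.granted_freq_max_position_embedding)  # 4096
print(limits.granted_mask_max_position_embedding)  # 32768
```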

model_type: str = 'qwen2'#
static rng_keys()[source]#

Returns the names of the random number generator keys used by the model.

class easydel.__init__.Qwen2ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Qwen2 model with a Causal Language Modeling head.

This model consists of the base Qwen2 transformer (Qwen2Model) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction. Optionally, the input token embeddings can be tied to the output projection layer.
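To illustrate what tying the input embeddings to the output projection means, here is a toy sketch in plain Python: the logit for each vocabulary entry is the dot product of the final hidden state with that entry's embedding row. The sizes and values are made up, and `matvec` is a hypothetical helper, not part of the library.

```python
def matvec(matrix, vector):
    """Multiply a (rows x cols) matrix by a length-cols vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


# Toy embedding table: vocabulary of 3 tokens, hidden size 2.
embedding = [
    [1.0, 0.0],
    [0.0, 1.0],
    [1.0, 1.0],
]
hidden_state = [2.0, -1.0]  # final hidden state for one position

# Tied weights: the embedding table is reused as the output projection.
logits = matvec(embedding, hidden_state)
next_token = max(range(len(logits)), key=logits.__getitem__)
print(logits, next_token)  # [2.0, -1.0, 1.0] 0
```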

config#

Configuration object for the model.

Type

Qwen2Config

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

model#

The core Qwen2 transformer model.

Type

Qwen2Model

lm_head#

The linear layer for projecting hidden states to vocabulary logits.

Type

ParallelLinear

class easydel.__init__.Qwen2ForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Qwen2 model with a Sequence Classification head.

This model consists of the base Qwen2 transformer (Qwen2Model) followed by a linear layer (score) that projects the transformer’s output hidden states (typically the hidden state of the last token or a pooled representation) to the number of classes for classification.
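Pooling on the hidden state of the last token can be sketched as below, assuming right-padded inputs and a known padding token ID. The helper names and the toy per-token scores are hypothetical; the real module may locate the last token differently.

```python
def last_token_index(input_ids, pad_token_id):
    """Index of the last non-padding token, assuming right padding."""
    for index in range(len(input_ids) - 1, -1, -1):
        if input_ids[index] != pad_token_id:
            return index
    raise ValueError("sequence contains only padding")


def pool_last_token(per_token_logits, input_ids, pad_token_id):
    """Select the classification logits at the last real token."""
    return per_token_logits[last_token_index(input_ids, pad_token_id)]


input_ids = [11, 12, 13, 0, 0]  # 0 is the padding ID in this toy example
per_token_logits = [[0.1, 0.9], [0.4, 0.6], [0.8, 0.2], [0.5, 0.5], [0.5, 0.5]]
pooled = pool_last_token(per_token_logits, input_ids, pad_token_id=0)
predicted_class = max(range(len(pooled)), key=pooled.__getitem__)
print(pooled, predicted_class)  # [0.8, 0.2] 0
```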

config#

Configuration object for the model.

Type

Qwen2Config

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

model#

The core Qwen2 transformer model.

Type

Qwen2Model

score#

The linear layer for classification.

Type

ParallelLinear

class easydel.__init__.Qwen2Model(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

The base Qwen2 model transformer.

This class represents the core transformer architecture of the Qwen2 model, consisting of an embedding layer, multiple Qwen2DecoderLayer layers, and a final RMS normalization layer.
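The final RMS normalization can be sketched on a single vector as follows. This is the standard RMSNorm formula written in plain Python for readability; `rms_norm` is a hypothetical helper and the epsilon matches the rms_norm_eps default of Qwen2Config.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """Scale x by the reciprocal of its root mean square, then by a learned weight."""
    mean_square = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_square + eps)
    return [w * v * scale for v, w in zip(x, weight)]


out = rms_norm([3.0, 4.0], [1.0, 1.0])
# After normalization with unit weights, the mean square is (almost exactly) 1.
out_mean_square = sum(v * v for v in out) / len(out)
print(out, out_mean_square)
```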

config#

Configuration object for the model.

Type

Qwen2Config

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

embed_tokens#

Embedding layer for input tokens.

Type

nn.Embed

layers#

List of decoder layers.

Type

tp.List[Qwen2DecoderLayer]

norm#

Final layer normalization.

Type

RMSNorm

dropout#

Dropout layer applied after embeddings.

Type

nn.Dropout

gradient_checkpointing#

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers

class easydel.__init__.Qwen2MoeConfig(vocab_size=151936, hidden_size=2048, intermediate_size=5632, num_hidden_layers=24, num_attention_heads=16, num_key_value_heads=16, hidden_act='silu', max_position_embeddings=32768, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, tie_word_embeddings=False, qkv_bias=False, rope_theta=10000.0, use_sliding_window=False, sliding_window=4096, max_window_layers=28, attention_dropout=0.0, decoder_sparse_step=1, moe_intermediate_size=1408, shared_expert_intermediate_size=5632, num_experts_per_tok=4, num_experts=60, norm_topk_prob=False, output_router_logits=False, router_aux_loss_coef=0.001, mlp_only_layers=None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 151936) – Vocabulary size of the Qwen-2 MoE model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 2048) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 5632) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 24) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 16) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 16) – Number of key and value heads for each attention layer in the Transformer encoder.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 32768) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • use_sliding_window (bool, optional, defaults to False) – Whether to use a sliding window attention.

  • sliding_window (int, optional, defaults to 4096) – The sliding window size.

  • max_window_layers (int, optional, defaults to 28) – The maximum number of layers to use for the sliding window attention.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • decoder_sparse_step (int, optional, defaults to 1) – The sparse step for the decoder.

  • moe_intermediate_size (int, optional, defaults to 1408) – The intermediate size of the MoE layer.

  • shared_expert_intermediate_size (int, optional, defaults to 5632) – The intermediate size of the shared expert.

  • num_experts_per_tok (int, optional, defaults to 4) – The number of experts per token.

  • num_experts (int, optional, defaults to 60) – The number of experts.

  • norm_topk_prob (bool, optional, defaults to False) – Whether to normalize the top-k probabilities.

  • output_router_logits (bool, optional, defaults to False) – Whether to output the router logits.

  • router_aux_loss_coef (float, optional, defaults to 0.001) – The coefficient for the router auxiliary loss.

  • mlp_only_layers (list of int, optional) – The layers that should only contain an MLP.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • bits (int, optional) – The number of bits to quantize the model to.
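The interaction of num_experts_per_tok and norm_topk_prob can be sketched for a single token as follows. The order of operations (softmax over all experts, then top-k selection, then optional renormalization) is an assumption based on common MoE routers, and `route` is a hypothetical helper.

```python
import math


def route(router_logits, num_experts_per_tok, norm_topk_prob):
    """Pick the top-k experts for one token and return (expert indices, weights)."""
    peak = max(router_logits)
    exps = [math.exp(logit - peak) for logit in router_logits]  # stable softmax
    total = sum(exps)
    probs = [e / total for e in exps]
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:num_experts_per_tok]
    weights = [probs[i] for i in top]
    if norm_topk_prob:
        # Renormalize so the selected experts' weights sum to 1.
        selected_total = sum(weights)
        weights = [w / selected_total for w in weights]
    return top, weights


logits = [2.0, 1.0, 0.0, -1.0]  # toy router output for 4 experts
experts, raw_weights = route(logits, num_experts_per_tok=2, norm_topk_prob=False)
_, normed_weights = route(logits, num_experts_per_tok=2, norm_topk_prob=True)
print(experts, raw_weights, normed_weights)
```

With norm_topk_prob=False the selected weights sum to less than 1, because the probability mass of the unselected experts is simply dropped.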

get_partition_rules(fully_sharded_data_parallel: bool = True)[source]#

Get the partition rules for the model.

Parameters

fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()[source]#

Returns a tuple of parameter names for which weight decay should be excluded.

property granted_freq_max_position_embedding: int#

Returns the maximum position embedding size specifically for frequency-based position embeddings.

If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.

Returns

The granted maximum position embedding size for frequency encoding.

Return type

int

property granted_mask_max_position_embedding: int#

Returns the maximum position embedding size specifically for mask-based position embeddings.

If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.

Returns

The granted maximum position embedding size for mask encoding.

Return type

int

model_type: str = 'qwen2_moe'#
static rng_keys()[source]#

Returns the names of the random number generator keys used by the model.

class easydel.__init__.Qwen2MoeForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Qwen2 MoE model with a Causal Language Modeling (CLM) head.

This class wraps the base Qwen2MoeModel and adds a linear layer (language model head) to predict the next token logits.

config#

Configuration object for the model.

Type

Qwen2MoeConfig

model#

The base Qwen2 MoE model.

Type

Qwen2MoeModel

lm_head#

The language model head (linear layer).

Type

ParallelLinear

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.__init__.Qwen2MoeForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Qwen2 MoE model with a sequence classification head.

This class wraps the base Qwen2MoeModel and adds a linear layer on top to perform sequence classification tasks.

config#

Configuration object for the model.

Type

Qwen2MoeConfig

model#

The base Qwen2 MoE model.

Type

Qwen2MoeModel

score#

The sequence classification head (linear layer).

Type

ParallelLinear

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.__init__.Qwen2MoeModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

The base Qwen2 MoE transformer model.

This class implements the core transformer architecture, including embedding layers, decoder layers, and final normalization.
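Which decoder layers carry a sparse MoE block is governed by decoder_sparse_step and mlp_only_layers in Qwen2MoeConfig. The sketch below follows the rule used by the Hugging Face Qwen2-MoE reference implementation; that this module applies the identical rule is an assumption, and `is_moe_layer` is a hypothetical helper.

```python
def is_moe_layer(layer_idx, num_experts, decoder_sparse_step, mlp_only_layers=()):
    """True if the layer uses a sparse MoE block, False if it uses a dense MLP."""
    if layer_idx in mlp_only_layers:
        return False
    return num_experts > 0 and (layer_idx + 1) % decoder_sparse_step == 0


# Toy configuration: 6 layers, every second layer sparse, layer 3 forced dense.
moe_layers = [
    idx for idx in range(6)
    if is_moe_layer(idx, num_experts=60, decoder_sparse_step=2, mlp_only_layers=(3,))
]
print(moe_layers)  # [1, 5]
```

With the default decoder_sparse_step=1 and no mlp_only_layers, every layer would be a MoE layer under this rule.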

config#

Configuration object for the model.

Type

Qwen2MoeConfig

embed_tokens#

Embedding layer for input tokens.

Type

nn.Embed

layers#

List of decoder layers.

Type

tp.List[Qwen2MoeDecoderLayer]

norm#

Final layer normalization.

Type

RMSNorm

gradient_checkpointing#

Gradient checkpointing strategy.

Type

EasyDeLGradientCheckPointers

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.__init__.Qwen2VLConfig(vocab_size=152064, hidden_size=8192, intermediate_size=29568, num_hidden_layers=80, num_attention_heads=64, num_key_value_heads=8, hidden_act='silu', max_position_embeddings=32768, initializer_range=0.02, rms_norm_eps=1e-05, use_cache=True, tie_word_embeddings=False, rope_theta=1000000.0, use_sliding_window=False, sliding_window=4096, max_window_layers=80, attention_dropout=0.0, vision_config=None, rope_scaling=None, vision_start_token_id=151652, vision_end_token_id=151653, vision_token_id=151654, image_token_id=151655, video_token_id=151656, **kwargs)[source]#

Bases: EasyDeLBaseConfig

This is the configuration class to store the configuration of a [Qwen2VLModel]. It is used to instantiate a Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 152064) – Vocabulary size of the Qwen2VL model. Defines the number of different tokens that can be represented by the input_ids passed when calling [Qwen2VLModel].

  • hidden_size (int, optional, defaults to 8192) – Dimension of the hidden representations.

  • intermediate_size (int, optional, defaults to 29568) – Dimension of the MLP representations.

  • num_hidden_layers (int, optional, defaults to 80) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 64) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 8) – The number of key/value heads used to implement Grouped Query Attention (GQA). If num_key_value_heads=num_attention_heads, the model uses Multi Head Attention (MHA); if num_key_value_heads=1, it uses Multi Query Attention (MQA); otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group's key and value head should be constructed by mean-pooling all the original heads within that group. For more details, see [this paper](https://arxiv.org/pdf/2305.13245.pdf).

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) in the decoder.

  • max_position_embeddings (int, optional, defaults to 32768) – The maximum sequence length that this model might ever be used with.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-05) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether the model’s input and output word embeddings should be tied.

  • rope_theta (float, optional, defaults to 1000000.0) – The base period of the RoPE embeddings.

  • use_sliding_window (bool, optional, defaults to False) – Whether to use sliding window attention.

  • sliding_window (int, optional, defaults to 4096) – Sliding window attention (SWA) window size. If not specified, will default to 4096.

  • max_window_layers (int, optional, defaults to 80) – The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • vision_config (tp.Dict, optional) – The config for the visual encoder initialization.

  • rope_scaling (tp.Dict, optional) –

    Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type and expect the model to work with a longer max_position_embeddings, we recommend updating this value accordingly. Expected contents:

    rope_type (str):

    The sub-variant of RoPE to use. Can be one of [‘default’, ‘linear’, ‘dynamic’, ‘yarn’, ‘longrope’, ‘llama3’], with ‘default’ being the original RoPE implementation.

    factor (float, optional):

    Used with all rope types except ‘default’. The scaling factor to apply to the RoPE embeddings. In most scaling types, a factor of x will enable the model to handle sequences of length x * original maximum pre-trained length.

    original_max_position_embeddings (int, optional):

    Used with ‘dynamic’, ‘longrope’ and ‘llama3’. The original max position embeddings used during pretraining.

    attention_factor (float, optional):

    Used with ‘yarn’ and ‘longrope’. The scaling factor to be applied on the attention computation. If unspecified, it defaults to value recommended by the implementation, using the factor field to infer the suggested value.

    beta_fast (float, optional):

    Only used with ‘yarn’. Parameter to set the boundary for extrapolation (only) in the linear ramp function. If unspecified, it defaults to 32.

    beta_slow (float, optional):

    Only used with ‘yarn’. Parameter to set the boundary for interpolation (only) in the linear ramp function. If unspecified, it defaults to 1.

    short_factor (tp.List[float], optional):

    Only used with ‘longrope’. The scaling factor to be applied to short contexts (< original_max_position_embeddings). Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2

    long_factor (tp.List[float], optional):

    Only used with ‘longrope’. The scaling factor to be applied to long contexts (> original_max_position_embeddings). Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2

    low_freq_factor (float, optional):

    Only used with ‘llama3’. Scaling factor applied to low frequency components of the RoPE

    high_freq_factor (float, optional):

    Only used with ‘llama3’. Scaling factor applied to high frequency components of the RoPE
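As a worked example of the factor field, the sketch below computes RoPE inverse frequencies and applies 'linear' scaling by dividing them by the factor, which is how linear position interpolation is commonly implemented. `rope_inv_freq` is a hypothetical helper; the head dimension of 128 is derived from the defaults above (8192 / 64) and the factor of 4 is an arbitrary choice.

```python
def rope_inv_freq(head_dim, rope_theta, factor=1.0):
    """Inverse frequencies for RoPE; factor > 1 applies linear scaling."""
    return [
        1.0 / (factor * rope_theta ** (2 * i / head_dim))
        for i in range(head_dim // 2)
    ]


base = rope_inv_freq(head_dim=128, rope_theta=1000000.0)
scaled = rope_inv_freq(head_dim=128, rope_theta=1000000.0, factor=4.0)

# With a factor of 4, position 32 is rotated by the same angles as position 8
# was before scaling, so sequences 4x longer stay within the trained range.
angle_scaled = 32 * scaled[10]
angle_base = 8 * base[10]
print(len(base), base[0], scaled[0])  # 64 1.0 0.25
```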

```python
>>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig

>>> # Initializing a Qwen2VL style configuration
>>> configuration = Qwen2VLConfig()
>>> # Initializing a model from the Qwen2-VL-7B style configuration
>>> model = Qwen2VLForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model parameters for use with distributed training.

These rules define how model parameters should be partitioned across multiple devices when using techniques like Fully Sharded Data Parallelism (FSDP), Sequence Parallelism (SP), and Tensor Parallelism (TP).

Returns

A tuple of tuples where each inner tuple contains:
  • A regex pattern matching parameter names

  • A PartitionSpec object specifying how to partition matching parameters

Return type

tuple

keys_to_ignore_at_inference = ['past_key_values']#
model_type: str = 'qwen2_vl'#
sub_configs: dict[str, 'PretrainedConfig'] = {'vision_config': <class 'easydel.__init__.modules.qwen2_vl.qwen2_vl_configuration.Qwen2VLVisionConfig'>}#
class easydel.__init__.Qwen2VLForConditionalGeneration(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

get_decoder()[source]#
get_input_embeddings()[source]#
get_output_embeddings()[source]#
get_static_arguments()[source]#

Returns a tuple of static arguments required by the module’s __call__ method.

Static arguments are those that don’t change across calls and can be potentially cached or handled differently by JIT compilation. This base implementation returns an empty tuple. Subclasses should override this if they have static arguments.

Returns

A tuple containing static arguments.

Return type

tp.Tuple

loss_type = 'ForCausalLM'#
prepare_inputs_for_call(image_grid_thw: Optional[Union[Array, ndarray, bool, number]] = None, video_grid_thw: Optional[Union[Array, ndarray, bool, number]] = None, image_max_grid_size: int = None, video_max_grid_size: int = None, drop_ids: bool = True, **others)[source]#

Prepares keyword arguments before passing them to the module’s __call__ method.

This base implementation simply returns the kwargs as is. Subclasses can override this to modify or add arguments as needed (e.g., for generation).

Parameters

**kwargs – The keyword arguments intended for __call__.

Returns

The prepared keyword arguments.

Return type

dict

prepare_inputs_for_generation(input_ids, max_length, past_key_values=None, attention_mask=None, inputs_embeds=None, position_ids=None, pixel_values=None, pixel_values_videos=None, image_grid_thw=None, video_grid_thw=None, **kwargs)[source]#

Prepares the model inputs for a generation task.

Parameters
  • input_ids – The input token IDs.

  • max_length – The maximum length of the sequence to be generated.

  • attention_mask – tp.Optional[chex.Array]: Mask applied to the attention weights.

  • token_type_ids – tp.Optional[chex.Array]: Token type IDs.

Returns

A dictionary of the past_key_values, attention_mask and position ids
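A rough sketch of what assembling such a dictionary can involve, for a single sequence in plain Python. It assumes the prompt mask is extended with ones up to max_length and that position IDs are derived from the cumulative sum of the mask; the real method also has to handle the cache and the vision inputs, which are only represented by a None placeholder here. `build_generation_inputs` is a hypothetical helper.

```python
def build_generation_inputs(attention_mask, max_length):
    """Extend a prompt mask to max_length and derive position ids from it."""
    prompt_length = len(attention_mask)
    if max_length < prompt_length:
        raise ValueError("max_length must be at least the prompt length")
    # Slots for tokens that will be generated are marked as attendable.
    extended_mask = list(attention_mask) + [1] * (max_length - prompt_length)
    position_ids, seen = [], 0
    for flag in attention_mask:
        seen += flag
        position_ids.append(max(seen - 1, 0))  # padded slots stay at position 0
    return {
        "past_key_values": None,  # placeholder; a real cache would be allocated here
        "attention_mask": extended_mask,
        "position_ids": position_ids,
    }


# Left-padded prompt with three real tokens, generating up to 8 tokens in total.
inputs = build_generation_inputs([0, 0, 1, 1, 1], max_length=8)
print(inputs["attention_mask"])  # [0, 0, 1, 1, 1, 1, 1, 1]
print(inputs["position_ids"])    # [0, 0, 0, 1, 2]
```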

update_inputs_for_generation(model_outputs, model_kwargs)[source]#
class easydel.__init__.Qwen2VLModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.Qwen3Config(vocab_size=151936, hidden_size=4096, intermediate_size=22016, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=32, head_dim=128, hidden_act='silu', max_position_embeddings=32768, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, attention_bias=False, use_sliding_window=False, sliding_window=4096, max_window_layers=28, attention_dropout=0.0, **kwargs)[source]#

Bases: EasyDeLBaseConfig

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Parameters

fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'qwen3'#
class easydel.__init__.Qwen3ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Qwen3 model with a Causal Language Modeling head.

This model consists of the base Qwen3 transformer (Qwen3Model) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction. Optionally, the input token embeddings can be tied to the output projection layer.

config#

Configuration object for the model.

Type

Qwen3Config

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

model#

The core Qwen3 transformer model.

Type

Qwen3Model

lm_head#

The linear layer for projecting hidden states to vocabulary logits.

Type

ParallelLinear

class easydel.__init__.Qwen3ForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Qwen3 model with a Sequence Classification head.

This model consists of the base Qwen3 transformer (Qwen3Model) followed by a linear layer (score) that projects the transformer’s output hidden states (typically the hidden state of the last token or a pooled representation) to the number of classes for classification.

config#

Configuration object for the model.

Type

Qwen3Config

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

model#

The core Qwen3 transformer model.

Type

Qwen3Model

score#

The linear layer for classification.

Type

ParallelLinear

class easydel.__init__.Qwen3Model(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

The base Qwen3 model transformer.

This class represents the core transformer architecture of the Qwen3 model, consisting of an embedding layer, multiple Qwen3DecoderLayer layers, and a final RMS normalization layer.

config#

Configuration object for the model.

Type

Qwen3Config

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

embed_tokens#

Embedding layer for input tokens.

Type

nn.Embed

layers#

List of decoder layers.

Type

tp.List[Qwen3DecoderLayer]

norm#

Final layer normalization.

Type

RMSNorm

gradient_checkpointing#

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers

class easydel.__init__.Qwen3MoeConfig(vocab_size=151936, hidden_size=2048, intermediate_size=6144, num_hidden_layers=24, num_attention_heads=32, num_key_value_heads=4, hidden_act='silu', max_position_embeddings=32768, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, attention_bias=False, use_sliding_window=False, sliding_window=4096, max_window_layers=28, attention_dropout=0.0, decoder_sparse_step=1, moe_intermediate_size=768, num_experts_per_tok=8, num_experts=128, norm_topk_prob=False, output_router_logits=False, router_aux_loss_coef=0.001, mlp_only_layers=None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Parameters

fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'qwen3_moe'#
class easydel.__init__.Qwen3MoeForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Qwen3Moe model with a Causal Language Modeling head.

This model consists of the base Qwen3Moe transformer (Qwen3MoeModel) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction. Optionally, the input token embeddings can be tied to the output projection layer.

config#

Configuration object for the model.

Type

Qwen3MoeConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

model#

The core Qwen3Moe transformer model.

Type

Qwen3MoeModel

lm_head#

The linear layer for projecting hidden states to vocabulary logits.

Type

ParallelLinear
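
The projection performed by lm_head can be illustrated with a minimal pure-Python sketch. This is not EasyDeL code: project_to_vocab and the toy matrices are hypothetical, and the sketch assumes that tying the embeddings means reusing the token embedding table (one row per vocabulary entry) as the output projection.

```python
def project_to_vocab(hidden, embedding):
    """Return one logit per vocabulary row: logits[v] = dot(hidden, embedding[v])."""
    return [sum(h * e for h, e in zip(hidden, row)) for row in embedding]


# Toy embedding table (assumption): vocabulary of 3 tokens, hidden size 2.
embedding = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

# Final hidden state of the last position, as produced by the base model.
hidden_state = [0.5, 2.0]

# With tied embeddings, the same table that embeds input tokens scores outputs.
logits = project_to_vocab(hidden_state, embedding)

# Greedy next-token prediction picks the highest-scoring vocabulary entry.
next_token = max(range(len(logits)), key=logits.__getitem__)
```

When the embeddings are not tied, the only change in this picture is that lm_head owns a separate weight matrix of the same shape.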

class easydel.__init__.Qwen3MoeForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Qwen3Moe model with a Sequence Classification head.

This model consists of the base Qwen3Moe transformer (Qwen3MoeModel) followed by a linear layer (score) that projects the transformer’s output hidden states (typically the hidden state of the last token or a pooled representation) to the number of classes for classification.

config#

Configuration object for the model.

Type

Qwen3MoeConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

model#

The core Qwen3Moe transformer model.

Type

Qwen3MoeModel

score#

The linear layer for classification.

Type

ParallelLinear
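
As a rough illustration of the pooling step described above, the sketch below selects the hidden state of the last non-padding token and applies a linear scoring layer. The function names, the right-padding assumption, and the toy values are all hypothetical; the actual pooling strategy used by the model may differ.

```python
def last_token_index(attention_mask):
    """Index of the last real token, assuming right padding (1 = token, 0 = pad)."""
    return sum(attention_mask) - 1


def classify(hidden_states, attention_mask, score_weights):
    """Pool the last real token's hidden state and project it to class logits."""
    pooled = hidden_states[last_token_index(attention_mask)]
    return [sum(w * h for w, h in zip(row, pooled)) for row in score_weights]


# Toy inputs: 3 positions, hidden size 2, the final position is padding.
hidden_states = [[1.0, 2.0], [0.25, 0.5], [0.0, 0.0]]
attention_mask = [1, 1, 0]

# Toy score layer with 3 classes (one weight row per class).
score_weights = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

class_logits = classify(hidden_states, attention_mask, score_weights)
```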

class easydel.__init__.Qwen3MoeModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

The base Qwen3Moe model transformer.

This class represents the core transformer architecture of the Qwen3Moe model, consisting of an embedding layer, multiple Qwen3MoeDecoderLayer layers, and a final RMS normalization layer.

config#

Configuration object for the model.

Type

Qwen3MoeConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

embed_tokens#

Embedding layer for input tokens.

Type

nn.Embed

layers#

List of decoder layers.

Type

tp.List[Qwen3MoeDecoderLayer]

norm#

Final layer normalization.

Type

RMSNorm

gradient_checkpointing#

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers

class easydel.__init__.RewardConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: tp.Optional[str] = None, clip_grad: tp.Optional[float] = None, custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = None, dataloader_num_workers: tp.Optional[int] = 0, dataloader_pin_memory: tp.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: tp.Optional[int] = None, evaluation_steps: tp.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: tp.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: tp.Optional[dict] = None, learning_rate: float = 5e-05, learning_rate_end: tp.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: tp.Optional[LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: tp.Optional[int] = None, max_sequence_length: ~typing.Optional[int] = 1024, max_training_steps: tp.Optional[int] = None, model_name: str = 'RewardTrainer', model_parameters: tp.Optional[dict] = None, metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = False, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: tp.Optional[int] = None, save_total_limit: tp.Optional[int] = None, scheduler: AVAILABLE_SCHEDULERS = EasyDeLSchedulers.NONE, 
sparsify_module: bool = False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: tp.Optional[dict] = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: tp.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: tp.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: tp.Optional[jnp.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: tp.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0, disable_dropout: bool = True, dataset_num_proc: ~typing.Optional[int] = None, center_rewards_coefficient: ~typing.Optional[float] = 0.1)[source]#

Bases: TrainingArguments

Configuration class for the [RewardTrainer].

Parameters
  • model_name (str) – The name of the model. Defaults to “RewardTrainer”.

  • max_sequence_length (int, optional) – Maximum length of the sequences (prompt + completion) in the batch; entries that exceed the limit are filtered out. Defaults to 1024.

  • disable_dropout (bool, optional) – Whether to disable dropout in the model. Defaults to True.

  • dataset_num_proc (int, optional) – Number of processes to use for processing the dataset. Defaults to None.

  • center_rewards_coefficient (float, optional) – Coefficient to incentivize the reward model to output mean-zero rewards. Defaults to 0.1.

  • remove_unused_columns (bool, optional) – Whether to remove the columns that are not used by the model’s forward pass. Can be True only if the dataset is pretokenized. Defaults to False.

center_rewards_coefficient: Optional[float] = 0.1#
dataset_num_proc: Optional[int] = None#
disable_dropout: bool = True#
classmethod from_dict(data)#

Create an instance from a dictionary (deserialization).

classmethod from_json(json_str)#

Create an instance from a JSON string.

max_sequence_length: Optional[int] = 1024#
model_name: str = 'RewardTrainer'#
remove_unused_columns: bool = False#
replace(**kwargs)#
to_dict()#

Convert the instance to a dictionary for JSON serialization.

to_json(**kwargs)#

Convert the instance to a JSON string.
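
To make the role of center_rewards_coefficient concrete, here is a minimal sketch of a pairwise reward loss with a centering penalty. It assumes the formulation used by TRL-style reward trainers (a negative log-sigmoid term on the reward margin plus a squared penalty on the sum of the two rewards); whether EasyDeL computes it exactly this way is an assumption, and reward_pair_loss is a hypothetical helper.

```python
import math


def reward_pair_loss(chosen, rejected, center_rewards_coefficient=None):
    """Loss for one (chosen, rejected) reward pair.

    The first term equals -log(sigmoid(chosen - rejected)) and pushes the
    chosen reward above the rejected one. The optional second term penalizes
    rewards whose sum drifts away from zero.
    """
    loss = math.log1p(math.exp(-(chosen - rejected)))
    if center_rewards_coefficient is not None:
        loss += center_rewards_coefficient * (chosen + rejected) ** 2
    return loss


# Same margin (zero), but the second pair is offset from zero and is penalized.
centered = reward_pair_loss(0.0, 0.0, center_rewards_coefficient=0.1)
shifted = reward_pair_loss(1.0, 1.0, center_rewards_coefficient=0.1)
```

Because the ranking term depends only on the difference between the two rewards, it cannot by itself pin down their absolute scale; the penalty is what keeps the outputs near mean zero.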

class easydel.__init__.RewardTrainer(arguments: RewardConfig, processing_class: Any, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, data_collator: Optional[RewardDataCollatorWithPadding] = None)[source]#

Bases: Trainer

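Trainer class for training reward models. This trainer extends the Trainer and provides functionality specific to reward modeling, configured through [RewardConfig].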

configure_functions() TrainerConfigureFunctionOutput[source]#

Configures and JIT-compiles the training and evaluation step functions.

This method prepares the functions that will be used during training and evaluation. It sets up sharding for the model parameters and optimizer state, JIT-compiles the training and evaluation functions with the appropriate static arguments and sharding constraints, and also sets up the checkpoint manager.

Returns

An object containing:
  • sharded_training_step_function: The compiled training step function.

  • sharded_evaluation_step_function: The compiled evaluation step function.

  • mesh: The device mesh used for computation.

  • checkpoint_manager: The checkpointer for saving/loading model state.

Return type

TrainerConfigureFunctionOutput

create_collect_function(max_sequence_length, truncation_mode='keep_end')[source]#

Creates a collate/collect function to process batches of data for training or evaluation.

This function returns a callable that takes a batch (a list of dictionaries) and converts it into a dictionary of JAX arrays. For models of class “ForCausalLMLoss”, it also performs truncation (either keeping the end or the start of the sequence) so that each sequence does not exceed the specified maximum length.

Parameters
  • max_sequence_length (int) – The maximum allowed sequence length.

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – Determines whether to keep the end or the start of the sequence when truncating. Defaults to “keep_end”.

Returns

A function that takes a batch (list of dicts) and returns a processed dict of arrays.

Return type

tp.Callable
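
The two truncation modes can be sketched in plain Python as follows. This is a simplified stand-in, not the trainer's implementation: truncate and collate are hypothetical names, the sketch keeps Python lists where the real function returns JAX arrays, it applies truncation to every field, and it assumes max_sequence_length is positive.

```python
def truncate(tokens, max_sequence_length, truncation_mode="keep_end"):
    """Cut a token list down to at most max_sequence_length entries."""
    if truncation_mode == "keep_end":
        return tokens[-max_sequence_length:]
    if truncation_mode == "keep_start":
        return tokens[:max_sequence_length]
    raise ValueError(f"unknown truncation_mode: {truncation_mode!r}")


def collate(batch, max_sequence_length, truncation_mode="keep_end"):
    """Turn a list of dicts into a dict of lists, truncating each sequence."""
    return {
        key: [truncate(example[key], max_sequence_length, truncation_mode) for example in batch]
        for key in batch[0]
    }


batch = [{"input_ids": [1, 2, 3, 4, 5]}, {"input_ids": [6, 7]}]
kept_end = collate(batch, max_sequence_length=3)
kept_start = collate(batch, max_sequence_length=3, truncation_mode="keep_start")
```

Sequences shorter than the limit pass through unchanged; padding to a common length is outside the scope of this sketch.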

class easydel.__init__.RobertaConfig(vocab_size=50265, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act='gelu', hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=514, type_vocab_size=1, initializer_range=0.02, layer_norm_eps=1e-05, pad_token_id=1, bos_token_id=0, eos_token_id=2, position_embedding_type='absolute', use_cache=True, classifier_dropout=None, gradient_checkpointing='nothing_saveable', **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 50265) – Vocabulary size of the RoBERTa model. Defines the number of different tokens that can be represented by the inputs_ids passed when calling RobertaModel.

  • hidden_size (int, optional, defaults to 768) – Dimensionality of the encoder layers and the pooler layer.

  • num_hidden_layers (int, optional, defaults to 12) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 12) – Number of attention heads for each attention layer in the Transformer encoder.

  • intermediate_size (int, optional, defaults to 3072) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • hidden_act (str or function, optional, defaults to "gelu") – The non-linear activation function (function or string) in the encoder and pooler. If string, "gelu", "relu", "swish" and "gelu_new" are supported.

  • hidden_dropout_prob (float, optional, defaults to 0.1) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

  • attention_probs_dropout_prob (float, optional, defaults to 0.1) – The dropout ratio for the attention probabilities.

  • max_position_embeddings (int, optional, defaults to 514) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048).

  • type_vocab_size (int, optional, defaults to 1) – The vocabulary size of the token_type_ids passed when calling RobertaModel.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • layer_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.

  • position_embedding_type (str, optional, defaults to "absolute") – Type of position embedding. Choose one of "absolute", "relative_key", "relative_key_query". For positional embeddings use "absolute". For more information on "relative_key", please refer to [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). For more information on "relative_key_query", please refer to Method 4 in [Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • classifier_dropout (float, optional) – The dropout ratio for the classification head.

  • gradient_checkpointing (str, optional, defaults to "nothing_saveable") – What to save during gradient checkpointing. Choose one of "nothing_saveable", "first_half_saveable", "full_saveable".

attach_custom_arguments(gradient_checkpointing='nothing_saveable', **kwargs)[source]#

Attach custom arguments to the configuration.

This method allows attaching additional custom arguments to the configuration that weren’t part of the initial configuration.

Parameters
  • gradient_checkpointing (str, optional, defaults to “nothing_saveable”) – What to save during gradient checkpointing. Choose one of “nothing_saveable”, “first_half_saveable”, “full_saveable”.

  • **kwargs – Additional custom arguments to be attached to the configuration.

get_partition_rules(fully_sharded_data_parallel: bool = True)[source]#

Get the partition rules for the model.

Parameters

fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'roberta'#
class easydel.__init__.RobertaForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.RobertaForMultipleChoice(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.RobertaForQuestionAnswering(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.RobertaForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.RobertaForTokenClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.SFTConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: tp.Optional[str] = None, clip_grad: tp.Optional[float] = None, custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = None, dataloader_num_workers: tp.Optional[int] = 0, dataloader_pin_memory: tp.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: tp.Optional[int] = None, evaluation_steps: tp.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: tp.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: tp.Optional[dict] = None, learning_rate: float = 2e-05, learning_rate_end: tp.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: tp.Optional[LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: tp.Optional[int] = None, max_sequence_length: tp.Optional[int] = 4096, max_training_steps: tp.Optional[int] = None, model_name: str = 'SFTTrainer', model_parameters: tp.Optional[dict] = None, metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: tp.Optional[int] = None, save_total_limit: tp.Optional[int] = None, scheduler: AVAILABLE_SCHEDULERS = EasyDeLSchedulers.NONE, sparsify_module: bool = 
False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: tp.Optional[dict] = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: tp.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: tp.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: tp.Optional[jnp.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: tp.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0, dataset_text_field: ~typing.Optional[str] = None, add_special_tokens: bool = False, packing: bool = False, dataset_num_proc: ~typing.Optional[int] = None, dataset_batch_size: int = 1000, dataset_kwargs: ~typing.Optional[dict[str, typing.Any]] = None, eval_packing: ~typing.Optional[bool] = None, num_of_sequences: int = 1024, chars_per_token: float = 3.6)[source]#

Bases: TrainingArguments

Configuration class for the [SFTTrainer].

Parameters
  • model_name (str) – The name of the model. Defaults to “SFTTrainer”.

  • dataset_text_field (str, optional) – Name of the text field of the dataset. If provided, the trainer will automatically create a [ConstantLengthDataset] based on dataset_text_field. Defaults to None.

  • packing (bool, optional) – Controls whether the [ConstantLengthDataset] packs the sequences of the dataset. Defaults to False.

  • learning_rate (float, optional) – Initial learning rate for [AdamW] optimizer. The default value replaces that of [~transformers.TrainingArguments]. Defaults to 2e-5.

  • dataset_num_proc (int, optional) – Number of processes to use for processing the dataset. Only used when packing=False. Defaults to None.

  • dataset_batch_size (int, optional) – Number of examples to tokenize per batch. If dataset_batch_size <= 0 or dataset_batch_size is None, tokenizes the full dataset as a single batch. Defaults to 1000.

  • dataset_kwargs (dict[str, Any], optional) – Dictionary of optional keyword arguments to pass when creating packed or non-packed datasets. Defaults to None.

  • eval_packing (bool, optional) – Whether to pack the eval dataset. If None, uses the same value as packing. Defaults to None.

  • num_of_sequences (int, optional) – Number of sequences to use for the [ConstantLengthDataset]. Defaults to 1024.

  • chars_per_token (float, optional) – Number of characters per token to use for the [ConstantLengthDataset]. See [chars_token_ratio](huggingface/trl) for more details. Defaults to 3.6.

add_special_tokens: bool = False#
chars_per_token: float = 3.6#
dataset_batch_size: int = 1000#
dataset_kwargs: Optional[dict[str, Any]] = None#
dataset_num_proc: Optional[int] = None#
dataset_text_field: Optional[str] = None#
eval_packing: Optional[bool] = None#
classmethod from_dict(data)#

Create an instance from a dictionary (deserialization).

classmethod from_json(json_str)#

Create an instance from a JSON string.

learning_rate: float = 2e-05#
model_name: str = 'SFTTrainer'#
num_of_sequences: int = 1024#
packing: bool = False#
replace(**kwargs)#
to_dict()#

Convert the instance to a dictionary for JSON serialization.

to_json(**kwargs)#

Convert the instance to a JSON string.
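
The packing-related options can be pictured with the sketch below. It assumes the behaviour of TRL's ConstantLengthDataset, where tokenized examples are concatenated (separated by an end-of-sequence token) and cut into fixed-length chunks, and where the character buffer is sized as seq_length * chars_per_token * num_of_sequences. Both helper functions and the eos_token_id argument are hypothetical and are not part of the EasyDeL API.

```python
def pack_sequences(tokenized_examples, seq_length, eos_token_id):
    """Concatenate examples into one stream and split it into full-length chunks."""
    stream = []
    for token_ids in tokenized_examples:
        stream.extend(token_ids)
        stream.append(eos_token_id)  # mark the boundary between examples
    full_chunks = len(stream) // seq_length
    # The trailing partial chunk is dropped in this sketch.
    return [stream[i * seq_length:(i + 1) * seq_length] for i in range(full_chunks)]


def max_buffer_chars(seq_length, chars_per_token=3.6, num_of_sequences=1024):
    """Approximate number of characters to buffer before tokenizing and packing."""
    return seq_length * chars_per_token * num_of_sequences


packed = pack_sequences([[1, 2, 3], [4, 5], [6]], seq_length=4, eos_token_id=0)
```

Packing trades example boundaries for throughput: every chunk is completely filled, so no compute is spent on padding tokens.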

class easydel.__init__.SFTTrainer(arguments: SFTConfig, processing_class: Any, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, formatting_func: Optional[Callable] = None, data_collator: Optional[DataCollatorForCompletionOnlyLM] = None)[source]#

Bases: Trainer

Trainer class for Supervised Fine-Tuning (SFT) of language models.

This trainer extends the Trainer and provides functionalities specific to supervised fine-tuning tasks.

class easydel.__init__.SamplingParams(max_tokens: int = 16, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repetition_penalty: float = 1.0, temperature: float = 0.0, top_p: float = 1.0, top_k: int = 0, min_p: float = 0.0, suppress_tokens: list[int] = <factory>)[source]#

Bases: object

Parameters controlling the sampling process during text generation.

max_tokens#

The maximum number of tokens to generate (excluding the prompt). Defaults to 16.

Type

int

presence_penalty#

Penalty applied to the logits of tokens already present in the generated sequence. Positive values discourage repetition. Defaults to 0.0.

Type

float

frequency_penalty#

Penalty applied to the logits of tokens based on their frequency in the generated sequence so far. Positive values discourage verbatim repetition. Defaults to 0.0.

Type

float

repetition_penalty#

Multiplicative penalty applied to the logits of previously seen tokens. Values > 1.0 discourage repetition, < 1.0 encourage it. Defaults to 1.0.

Type

float

temperature#

Controls the randomness of the sampling. Higher values (e.g., > 1.0) make the distribution flatter (more random), lower values (e.g., < 1.0) make it peakier (more deterministic). A value of 0.0 effectively becomes greedy sampling. Defaults to 0.0.

Type

float

top_p#

Nucleus sampling threshold. If set to a value < 1.0, sampling is restricted to the smallest set of most probable tokens whose cumulative probability reaches top_p. Defaults to 1.0 (no nucleus sampling).

Type

float

top_k#

Top-k sampling threshold. If set to a value > 0, only the top_k most probable tokens are considered for sampling. Defaults to 0 (no top-k sampling).

Type

int

min_p#

Minimum probability threshold. Filters out tokens with probability less than min_p. Defaults to 0.0 (no minimum probability filtering).

Type

float

suppress_tokens#

A list of token IDs that should be completely suppressed (their logits set to -inf) during generation. Defaults to an empty list.

Type

list[int]

frequency_penalty: float = 0.0#
classmethod from_dict(data)#

Create an instance from a dictionary (deserialization).

classmethod from_json(json_str)#

Create an instance from a JSON string.

get_logits_processor()[source]#

Constructs a LogitsProcessorList containing the configured logits processors.

Logits processors modify the logits directly, often used for applying penalties (presence, frequency, repetition) or suppressing specific tokens.

Returns

A LogitsProcessorList containing the enabled logits processors based on the sampling parameters.
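
The penalties listed above can be sketched in pure Python. The formulas below follow conventions common to several inference libraries (subtractive presence and frequency penalties, a multiplicative repetition penalty that divides positive logits and multiplies negative ones); they are assumptions rather than a transcription of the EasyDeL processors, and apply_penalties is a hypothetical helper.

```python
from collections import Counter


def apply_penalties(logits, generated_ids, presence_penalty=0.0, frequency_penalty=0.0,
                    repetition_penalty=1.0, suppress_tokens=()):
    """Return a penalized copy of logits given the tokens generated so far."""
    out = list(logits)
    for token, count in Counter(generated_ids).items():
        # Repetition penalty: shrink positive logits, push negative ones lower.
        if out[token] > 0:
            out[token] /= repetition_penalty
        else:
            out[token] *= repetition_penalty
        # Presence is a flat cost; frequency scales with the occurrence count.
        out[token] -= presence_penalty + frequency_penalty * count
    for token in suppress_tokens:
        out[token] = float("-inf")  # token can never be sampled
    return out


penalized = apply_penalties(
    [2.0, -1.0, 0.5, 3.0],
    generated_ids=[0, 0, 1],
    presence_penalty=0.5,
    frequency_penalty=0.25,
    repetition_penalty=2.0,
    suppress_tokens=[3],
)
```

With the default values (0.0, 0.0, 1.0 and no suppressed tokens) the function returns the logits unchanged, which matches the defaults documented for SamplingParams.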

get_logits_warper()[source]#

Constructs a LogitsProcessorList containing the configured logits warpers.

Logits warpers modify the probability distribution derived from logits, typically used for techniques like temperature scaling, top-k, top-p, and min-p sampling.

Returns

A LogitsProcessorList containing the enabled logits warpers based on the sampling parameters.
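
A simplified view of what the warpers do is sketched below. It is not the EasyDeL implementation: warp_logits is a hypothetical helper, every filter is evaluated against the same temperature-scaled distribution instead of being chained, and min_p is treated as an absolute probability threshold as described above (some libraries scale it by the top token's probability instead).

```python
import math

NEG_INF = float("-inf")


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def warp_logits(logits, temperature=0.0, top_k=0, top_p=1.0, min_p=0.0):
    """Mask filtered-out tokens with -inf; the survivors remain eligible for sampling."""
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    if temperature == 0.0:
        # Greedy decoding: only the single most likely token survives.
        return [x if i == order[0] else NEG_INF for i, x in enumerate(logits)]
    scaled = [x / temperature for x in logits]
    probs = softmax(scaled)
    keep = set(order)
    if top_k > 0:
        keep &= set(order[:top_k])
    if top_p < 1.0:
        nucleus, cumulative = set(), 0.0
        for i in order:
            nucleus.add(i)
            cumulative += probs[i]
            if cumulative >= top_p:
                break  # smallest prefix whose mass reaches top_p
        keep &= nucleus
    if min_p > 0.0:
        keep &= {i for i in order if probs[i] >= min_p}
    return [x if i in keep else NEG_INF for i, x in enumerate(scaled)]


# Logits chosen so that the softmax is approximately (0.5, 0.3, 0.15, 0.05).
example_logits = [math.log(p) for p in (0.5, 0.3, 0.15, 0.05)]
top2 = warp_logits(example_logits, temperature=1.0, top_k=2)
```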

max_tokens: int = 16#
min_p: float = 0.0#
presence_penalty: float = 0.0#
repetition_penalty: float = 1.0#
replace(**kwargs)#
suppress_tokens: list[int]#
temperature: float = 0.0#
to_dict()#

Convert the instance to a dictionary for JSON serialization.

to_json(**kwargs)#

Convert the instance to a JSON string.

top_k: int = 0#
top_p: float = 1.0#
class easydel.__init__.SiglipConfig(text_config=None, vision_config=None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

[SiglipConfig] is the configuration class to store the configuration of a [SiglipModel]. It is used to instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • text_config (dict, optional) – Dictionary of configuration options used to initialize [SiglipTextConfig].

  • vision_config (dict, optional) – Dictionary of configuration options used to initialize [SiglipVisionConfig].

  • kwargs (optional) – Dictionary of keyword arguments.

classmethod from_text_vision_configs(text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs)[source]#

Instantiate a [SiglipConfig] (or a derived class) from siglip text model configuration and siglip vision model configuration.

Returns

An instance of a configuration object

Return type

[SiglipConfig]

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the SigLIP model parameters for use with distributed training.

These rules define how parameters should be partitioned across multiple devices when using techniques like Fully Sharded Data Parallelism (FSDP), Sequence Parallelism (SP), and Tensor Parallelism (TP). Each rule consists of a regex pattern for matching parameter names and a corresponding PartitionSpec.

Returns

A tuple of tuples where each inner tuple contains:
  • A regex pattern matching parameter names

  • A PartitionSpec object specifying how to partition matching parameters

Return type

tuple
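
How such (pattern, spec) rules are typically consumed can be sketched as follows. The rule set, the parameter names, and match_partition_spec are hypothetical, plain tuples of axis names stand in for jax.sharding.PartitionSpec objects so the sketch runs without JAX, and first-match-wins ordering is an assumption about how the rules are applied.

```python
import re


def match_partition_spec(param_name, rules):
    """Return the spec of the first rule whose regex matches the parameter name."""
    for pattern, spec in rules:
        if re.search(pattern, param_name):
            return spec
    raise ValueError(f"no partition rule matches {param_name!r}")


# Hypothetical rules: tuples of mesh axis names stand in for PartitionSpec.
rules = (
    (r"self_attn/(q_proj|k_proj|v_proj)/kernel", ("fsdp", "tp")),
    (r"self_attn/out_proj/kernel", ("tp", "fsdp")),
    (r".*", (None,)),  # catch-all: replicate everything else
)

spec = match_partition_spec("vision_model/layers/0/self_attn/q_proj/kernel", rules)
fallback = match_partition_spec("vision_model/post_layernorm/scale", rules)
```

Ordering matters under this assumption: the catch-all pattern must come last, otherwise it would shadow the more specific rules.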

model_type: str = 'siglip'#

The model type identifier used to determine which model configuration this represents. This is set to “siglip” to identify this as the main configuration for the SigLIP model.

sub_configs: dict[str, 'PretrainedConfig'] = {'text_config': <class 'easydel.__init__.modules.siglip.configuration_siglip.SiglipTextConfig'>, 'vision_config': <class 'easydel.__init__.modules.siglip.configuration_siglip.SiglipVisionConfig'>}#

A dictionary that maps configuration keys to their respective configuration classes. This enables the SiglipConfig to manage both text and vision components through separate configurations while maintaining them as part of a single unified model.

class easydel.__init__.SiglipForImageClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.SiglipModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

get_image_features(pixel_values: Optional[Union[Array, ndarray, bool, number]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, interpolate_pos_encoding: bool = False) Union[Array, ndarray, bool, number][source]#
get_text_features(input_ids: Optional[Union[Array, ndarray, bool, number]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None, position_ids: Optional[Union[Array, ndarray, bool, number]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None) Union[Array, ndarray, bool, number][source]#
class easydel.__init__.SiglipTextConfig(vocab_size=32000, hidden_size=768, intermediate_size=3072, num_hidden_layers=12, num_attention_heads=12, max_position_embeddings=64, hidden_act='gelu_pytorch_tanh', layer_norm_eps=1e-06, attention_dropout=0.0, pad_token_id=1, bos_token_id=49406, eos_token_id=49407, projection_size=None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 32000) – Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by the inputs_ids passed when calling [SiglipModel].

  • hidden_size (int, optional, defaults to 768) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 3072) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 12) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 12) – Number of attention heads for each attention layer in the Transformer encoder.

  • max_position_embeddings (int, optional, defaults to 64) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048).

  • hidden_act (str or function, optional, defaults to “gelu_pytorch_tanh”) – The non-linear activation function (function or string) in the encoder and pooler. If string, “gelu”, “relu”, “selu”, “gelu_new” and “quick_gelu” are supported.

  • layer_norm_eps (float, optional, defaults to 1e-06) – The epsilon used by the layer normalization layers.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • pad_token_id (int, optional, defaults to 1) – The id of the padding token in the vocabulary.

  • bos_token_id (int, optional, defaults to 49406) – The id of the beginning-of-sequence token in the vocabulary.

  • eos_token_id (int, optional, defaults to 49407) – The id of the end-of-sequence token in the vocabulary.

  • projection_size (int, optional, defaults to hidden_size) – The size of the projection head.

Example:

```python
>>> from transformers import SiglipTextConfig, SiglipTextModel

>>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration
>>> configuration = SiglipTextConfig()
>>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration
>>> model = SiglipTextModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
base_config_key: str = 'text_config'#
get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'siglip_text_model'#
class easydel.__init__.SiglipTextModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.SiglipVisionConfig(hidden_size=768, intermediate_size=3072, num_hidden_layers=12, num_attention_heads=12, num_channels=3, image_size=224, patch_size=16, hidden_act='gelu_pytorch_tanh', layer_norm_eps=1e-06, attention_dropout=0.0, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • hidden_size (int, optional, defaults to 768) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 3072) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 12) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 12) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_channels (int, optional, defaults to 3) – Number of channels in the input images.

  • image_size (int, optional, defaults to 224) – The size (resolution) of each image.

  • patch_size (int, optional, defaults to 16) – The size (resolution) of each patch.

  • hidden_act (str or function, optional, defaults to “gelu_pytorch_tanh”) – The non-linear activation function (function or string) in the encoder and pooler. If string, “gelu”, “relu”, “selu”, “gelu_new” and “quick_gelu” are supported.

  • layer_norm_eps (float, optional, defaults to 1e-06) – The epsilon used by the layer normalization layers.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
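
The interaction between image_size and patch_size can be worked out with a few lines of arithmetic. The sketch assumes the vision encoder splits a square image into non-overlapping square patches, as in a standard Vision Transformer; num_patches is a hypothetical helper, not part of the configuration class.

```python
def num_patches(image_size=224, patch_size=16):
    """Number of patch tokens for a square image cut into square patches."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    return per_side * per_side


# With the documented defaults: 224 / 16 = 14 patches per side, 14 * 14 in total.
default_patches = num_patches()
```

Under this assumption the default configuration yields 196 patch tokens per image, and halving patch_size quadruples that count.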

base_config_key: str = 'vision_config'#
get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'siglip_vision_model'#
class easydel.__init__.SiglipVisionModel(*args: Any, **kwargs: Any)[source]#

Bases: Module

class easydel.__init__.StableLmConfig(vocab_size=50304, intermediate_size=6912, hidden_size=2560, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=32, hidden_act='silu', max_position_embeddings=4096, initializer_range=0.02, layer_norm_eps=1e-05, use_cache=True, tie_word_embeddings=False, rope_theta=10000, rope_scaling=None, use_qkv_bias=False, qk_layernorm=False, use_parallel_residual=False, hidden_dropout=0.0, attention_dropout=0.0, partial_rotary_factor=0.25, bos_token_id=0, eos_token_id=0, bits: Optional[int] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 50304) – Vocabulary size of the StableLM model. Defines the number of different tokens that can be represented by the inputs_ids passed when calling [~easydel.modules.StableLmModel].

  • hidden_size (int, optional, defaults to 2560) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 6912) – Dimensionality of the “intermediate” (often named feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 32) – Number of key-value heads for each attention layer in the Transformer encoder.

  • hidden_act (str or tp.Callable, optional, defaults to “silu”) – The non-linear activation function (function or string) in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 4096) – The maximum sequence length that this model might ever be used with.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • layer_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models).

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (int, optional, defaults to 10000) – The theta value for the rotary position embeddings.

  • rope_scaling (str, optional) – The scaling to use for the rotary position embeddings.

  • qk_layernorm (bool, optional, defaults to False) – Whether to use layer normalization on the queries and keys in the attention layer.

  • use_parallel_residual (bool, optional, defaults to False) – Whether to use a parallel residual connection in the attention layer.

  • hidden_dropout (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • partial_rotary_factor (float, optional, defaults to 0.25) – The factor to scale the partial rotary embeddings by.

  • bos_token_id (int, optional, defaults to 0) – The id for the beginning of stream token.

  • eos_token_id (int, optional, defaults to 0) – The id for the end of stream token.

  • bits (int, optional) – The number of bits to quantize the model to. If None, the model is not quantized.

  • gradient_checkpointing (str, optional, defaults to “nothing_saveable”) – What to save during gradient checkpointing. Choose one of “nothing_saveable”, “first_half_saveable”, “full_saveable”.
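To make the default sizes concrete, the sketch below derives the per-head dimension and the number of rotary dimensions from the defaults in the signature. It assumes partial_rotary_factor selects the fraction of each head's dimensions that receive rotary embeddings, as in other StableLM implementations; that reading is an assumption, and `rotary_dims` is a hypothetical helper.

```python
def rotary_dims(
    hidden_size: int = 2560,
    num_attention_heads: int = 32,
    partial_rotary_factor: float = 0.25,
) -> tuple[int, int]:
    """Return (head_dim, rotary_ndims) under the stated assumption."""
    head_dim = hidden_size // num_attention_heads
    # Assumed convention: only this many leading dims of each head are rotated.
    rotary_ndims = int(head_dim * partial_rotary_factor)
    return head_dim, rotary_ndims


head_dim, rotary_ndims = rotary_dims()
print(head_dim, rotary_ndims)
```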

get_partition_rules(fully_sharded_data_parallel: bool = True)[source]#

Get the partition rules for the model.

Parameters

fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]
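The return type suggests a tuple of (pattern, PartitionSpec) pairs. A minimal sketch of how such rules could be consumed follows, assuming the common convention that each string is a regular expression tested against a parameter path and the first match wins. The rules, parameter names, and the `match_rule` helper are hypothetical, and plain tuples stand in for PartitionSpec so the example runs without JAX.

```python
import re

# Hypothetical rules; plain tuples stand in for jax.sharding.PartitionSpec.
RULES = (
    (r"embed_tokens/embedding", ("tp", ("fsdp", "sp"))),
    (r"self_attn/(q_proj|k_proj|v_proj)/kernel", (("fsdp", "sp"), "tp")),
    (r".*", (None,)),  # catch-all: leave everything else replicated
)


def match_rule(param_name: str, rules=RULES):
    """Return the spec of the first rule whose pattern matches the name."""
    for pattern, spec in rules:
        if re.search(pattern, param_name):
            return spec
    raise ValueError(f"no partition rule matches {param_name!r}")


print(match_rule("model/layers/0/self_attn/q_proj/kernel"))
print(match_rule("model/norm/scale"))
```

Ordering matters under this convention: a catch-all pattern placed first would shadow every more specific rule.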

property granted_freq_max_position_embedding: int#

Returns the maximum position embedding size specifically for frequency-based position embeddings.

If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.

Returns

The granted maximum position embedding size for frequency encoding.

Return type

int

property granted_mask_max_position_embedding: int#

Returns the maximum position embedding size specifically for mask-based position embeddings.

If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.

Returns

The granted maximum position embedding size for mask encoding.

Return type

int
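The fallback behavior of the two properties can be sketched with a small stand-in class. `PositionLimits` is hypothetical and is not the EasyDeL config class; the sketch also assumes that "set" means "not None".

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class PositionLimits:
    """Hypothetical stand-in mirroring the documented fallback logic."""

    max_position_embeddings: int = 4096
    freq_max_position_embeddings: Optional[int] = None
    mask_max_position_embeddings: Optional[int] = None

    @property
    def granted_freq_max_position_embedding(self) -> int:
        # Prefer the frequency-specific limit, else fall back to the general one.
        if self.freq_max_position_embeddings is not None:
            return self.freq_max_position_embeddings
        return self.max_position_embeddings

    @property
    def granted_mask_max_position_embedding(self) -> int:
        # Same fallback rule for the mask-specific limit.
        if self.mask_max_position_embeddings is not None:
            return self.mask_max_position_embeddings
        return self.max_position_embeddings


limits = PositionLimits(freq_max_position_embeddings=8192)
print(limits.granted_freq_max_position_embedding)
print(limits.granted_mask_max_position_embedding)
```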

model_type: str = 'stablelm'#
class easydel.__init__.StableLmForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

StableLM model with a Causal Language Modeling (CLM) head.

This class wraps the base StableLmModel and adds a linear layer (language model head) to predict the next token logits.

config#

Configuration object for the model.

Type

StableLmConfig

model#

The base StableLM model.

Type

StableLmModel

lm_head#

The language model head (linear layer).

Type

ParallelLinear

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.__init__.StableLmModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

The base StableLM transformer model.

This class implements the core transformer architecture, including embedding layers, decoder layers, and final normalization.

config#

Configuration object for the model.

Type

StableLmConfig

embed_tokens#

Embedding layer for input tokens.

Type

nn.Embed

layers#

List of decoder layers.

Type

nn.List[StableLmDecoderLayer]

norm#

Final layer normalization.

Type

nn.LayerNorm

gradient_checkpointing#

Gradient checkpointing strategy.

Type

str

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

property frequencies#

Cached property for precomputed rotary frequencies.

class easydel.__init__.TaskType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: str, Enum

Enumeration defining different model task types supported by the registry.

CAUSAL_LM#

Causal Language Modeling (e.g., GPT-style models).

VISION_LM#

Vision Language Modeling (models combining vision and text).

IMAGE_TEXT_TO_TEXT#

Models that take image and text input to produce text output.

BASE_MODULE#

Basic, potentially abstract, modules.

BASE_VISION#

Basic vision modules.

SEQUENCE_TO_SEQUENCE#

Sequence-to-sequence tasks (e.g., translation, summarization).

SPEECH_SEQUENCE_TO_SEQUENCE#

Speech-to-text or other speech sequence tasks.

ZERO_SHOT_IMAGE_CLASSIFICATION#

Image classification without task-specific training.

SEQUENCE_CLASSIFICATION#

Classifying entire sequences (e.g., sentiment analysis).

AUDIO_CLASSIFICATION#

Classifying audio data.

IMAGE_CLASSIFICATION#

Classifying images.

AUDIO_CLASSIFICATION = 'audio-classification'#
BASE_MODULE = 'base-module'#
BASE_VISION = 'vision-module'#
CAUSAL_LM = 'causal-language-model'#
IMAGE_CLASSIFICATION = 'image-classification'#
IMAGE_TEXT_TO_TEXT = 'image-text-to-text'#
SEQUENCE_CLASSIFICATION = 'sequence-classification'#
SEQUENCE_TO_SEQUENCE = 'sequence-to-sequence'#
SPEECH_SEQUENCE_TO_SEQUENCE = 'speech-sequence-to-sequence'#
VISION_LM = 'vision-language-model'#
ZERO_SHOT_IMAGE_CLASSIFICATION = 'zero-shot-image-classification'#
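Because TaskType derives from both str and Enum, its members compare equal to their string values and can be looked up by value. The sketch below uses an abbreviated local copy of the enum (three members, with values taken from the listing above) so that it runs without EasyDeL installed.

```python
from enum import Enum


class TaskType(str, Enum):
    """Abbreviated local copy for illustration; values match the listing."""

    CAUSAL_LM = "causal-language-model"
    SEQUENCE_CLASSIFICATION = "sequence-classification"
    IMAGE_TEXT_TO_TEXT = "image-text-to-text"


# Look a member up by its string value, e.g. one read from a config file.
task = TaskType("causal-language-model")
assert task is TaskType.CAUSAL_LM

# str mixin: members compare equal to plain strings.
assert task == "causal-language-model"
print(task.value)
```

Looking up an unknown value, such as `TaskType("unknown")`, raises ValueError.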
class easydel.__init__.Trainer(arguments: tp.Optional[TrainingArguments] = None, model_state: tp.Optional[EasyDeLState] = None, model: tp.type[EasyDeLBaseModule] = None, dataset_train: tp.Optional[Dataset] = None, dataset_eval: tp.Optional[Dataset] = None, data_collator: tp.Optional[tp.Callable] = None, finetune: bool = True, checkpoint_path: tp.Optional[tp.Union[str, os.PathLike]] = None, **deprecated_kwargs)[source]#

Bases: BaseTrainer

configure_functions() TrainerConfigureFunctionOutput[source]#

Configures and JIT-compiles the training and evaluation step functions.

This method prepares the functions that will be used during training and evaluation. It sets up sharding for the model parameters and optimizer state, JIT-compiles the training and evaluation functions with the appropriate static arguments and sharding constraints, and also sets up the checkpoint manager.

Returns

An object containing:
  • sharded_training_step_function: The compiled training step function.

  • sharded_evaluation_step_function: The compiled evaluation step function.

  • mesh: The device mesh used for computation.

  • checkpoint_manager: The checkpointer for saving/loading model state.

Return type

TrainerConfigureFunctionOutput

create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') Callable[source]#

Creates a collate/collect function to process batches of data for training or evaluation.

This function returns a callable that takes a batch (a list of dictionaries) and converts it into a dictionary of JAX arrays. For models of class “ForCausalLMLoss”, it also performs truncation (either keeping the end or the start of the sequence) so that each sequence does not exceed the specified maximum length.

Parameters
  • max_sequence_length (int) – The maximum allowed sequence length.

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – Determines whether to keep the end or the start of the sequence when truncating. Defaults to “keep_end”.

Returns

A function that takes a batch (list of dicts) and returns a processed dict of arrays.

Return type

tp.Callable
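The two truncation modes can be sketched on plain Python lists. This is not the EasyDeL collate function, which returns JAX arrays; `truncate` is a hypothetical helper that only illustrates which end of the sequence survives.

```python
from typing import Literal


def truncate(
    ids: list[int],
    max_sequence_length: int,
    truncation_mode: Literal["keep_end", "keep_start"] = "keep_end",
) -> list[int]:
    """Clip a token list to max_sequence_length, keeping the chosen end."""
    if max_sequence_length <= 0:
        raise ValueError("max_sequence_length must be positive")
    if len(ids) <= max_sequence_length:
        return ids  # already short enough
    if truncation_mode == "keep_end":
        return ids[-max_sequence_length:]
    if truncation_mode == "keep_start":
        return ids[:max_sequence_length]
    raise ValueError(f"unknown truncation_mode: {truncation_mode!r}")


print(truncate([1, 2, 3, 4, 5], 3))                # keeps the last three tokens
print(truncate([1, 2, 3, 4, 5], 3, "keep_start"))  # keeps the first three tokens
```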

eval(model_state: EasyDeLState) Iterator[dict][source]#

Evaluates the model using the provided model state.

This method iterates over the evaluation dataset, performs forward passes, calculates evaluation metrics, logs the metrics, and yields the metrics for each evaluation step.

Parameters

model_state (EasyDeLState) – The state of the model (including parameters and configuration) to be used for evaluation.

Yields

Iterator[dict] – An iterator yielding a dictionary of evaluation metrics for each evaluation step.

Raises

AssertionError – If the evaluation dataloader is not set.

train() TrainerOutput[source]#

Executes the complete training process.

This method sets up initial metrics and logging, runs the training loop, and finalizes training. It calls the training hook at the beginning and returns a TrainerOutput object at the end.

Returns

An object containing the final training state, metrics, and any additional outputs.

Return type

TrainerOutput

class easydel.__init__.TrainingArguments(auto_shard_states: 'bool' = True, aux_loss_enabled: 'bool' = False, backend: 'tp.Optional[str]' = None, clip_grad: 'tp.Optional[float]' = None, custom_scheduler: 'tp.Optional[tp.Callable[[int], tp.Any]]' = None, dataloader_num_workers: 'tp.Optional[int]' = 0, dataloader_pin_memory: 'tp.Optional[bool]' = False, do_eval: 'bool' = False, do_last_save: 'bool' = True, do_train: 'bool' = True, eval_batch_size: 'tp.Optional[int]' = None, evaluation_steps: 'tp.Optional[int]' = None, extra_optimizer_kwargs: 'dict' = <factory>, frozen_parameters: 'tp.Optional[str]' = None, gradient_accumulation_steps: 'int' = 1, ids_to_pop_from_dataset: 'tp.Optional[tp.List[str]]' = <factory>, is_fine_tuning: 'bool' = True, init_tx: 'bool' = True, jax_distributed_config: 'tp.Optional[dict]' = None, learning_rate: 'float' = 5e-05, learning_rate_end: 'tp.Optional[float]' = None, log_all_workers: 'bool' = False, log_grad_norms: 'bool' = True, report_metrics: 'bool' = True, log_steps: 'int' = 10, loss_config: 'tp.Optional[LossConfig]' = None, low_mem_usage: 'bool' = True, max_evaluation_steps: 'tp.Optional[int]' = None, max_sequence_length: 'tp.Optional[int]' = 4096, max_training_steps: 'tp.Optional[int]' = None, model_name: 'str' = 'BaseTrainer', model_parameters: 'tp.Optional[dict]' = None, metrics_to_show_in_rich_pbar: 'tp.Optional[tp.List[str]]' = None, num_train_epochs: 'int' = 10, offload_dataset: 'bool' = False, offload_device_type: 'str' = 'cpu', offload_device_index: 'int' = 0, optimizer: 'AVAILABLE_OPTIMIZERS' = <EasyDeLOptimizers.ADAMW: 'adamw'>, performance_mode: 'bool' = False, pruning_module: 'tp.Any' = None, process_zero_is_admin: 'bool' = True, progress_bar_type: "tp.Literal['tqdm', 'rich', 'json']" = 'tqdm', remove_ckpt_after_load: 'bool' = False, remove_unused_columns: 'bool' = True, report_steps: 'int' = 5, save_directory: 'str' = 'EasyDeL-Checkpoints', save_optimizer_state: 'bool' = True, save_steps: 'tp.Optional[int]' = None, 
save_total_limit: 'tp.Optional[int]' = None, scheduler: 'AVAILABLE_SCHEDULERS' = <EasyDeLSchedulers.NONE: 'None'>, sparsify_module: 'bool' = False, sparse_module_type: 'AVAILABLE_SPARSE_MODULE_TYPES' = 'bcoo', state_apply_fn_kwarguments_to_model: 'tp.Optional[dict]' = None, step_partition_spec: 'PartitionSpec' = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: 'tp.Optional[int]' = None, shuffle_train_dataset: 'bool' = True, total_batch_size: 'int' = 32, training_time_limit: 'tp.Optional[str]' = None, train_on_inputs: 'bool' = True, truncation_mode: "tp.Literal['keep_end', 'keep_start']" = 'keep_end', tx_mu_dtype: 'tp.Optional[jnp.dtype]' = None, track_memory: 'bool' = False, use_data_collactor: 'bool' = True, use_wandb: 'bool' = True, verbose: 'bool' = True, wandb_entity: 'tp.Optional[str]' = None, warmup_steps: 'int' = 0, weight_decay: 'float' = 0.01, weight_distribution_pattern: 'str' = '.*?(layernorm|norm).*?', weight_distribution_log_steps: 'int' = 0)[source]#

Bases: object

auto_shard_states: bool = True#
aux_loss_enabled: bool = False#
backend: Optional[str] = None#
clip_grad: Optional[float] = None#
custom_scheduler: Optional[Callable[[int], Any]] = None#
dataloader_num_workers: Optional[int] = 0#
dataloader_pin_memory: Optional[bool] = False#
do_eval: bool = False#
do_last_save: bool = True#
do_train: bool = True#
ensure_checkpoint_path()[source]#

Creates the checkpoint directory if it doesn’t exist.

ensure_training_time_limit(time_passed)[source]#
eval_batch_size: Optional[int] = None#
evaluation_steps: Optional[int] = None#
extra_optimizer_kwargs: dict#
classmethod from_dict(data)#

Create an instance from a dictionary (deserialization).

classmethod from_json(json_str)#

Create an instance from a JSON string.

frozen_parameters: Optional[str] = None#
get_optimizer_and_scheduler(steps: Optional[int] = None)[source]#

Returns the configured optimizer and learning rate scheduler.

Parameters

steps (tp.Optional[int]) – The number of training steps. If not provided, uses the value from self.optimizer_kwargs.

Returns

A tuple containing the optimizer and scheduler.

Return type

tuple

get_path() Path[source]#

Returns the path to the checkpoint directory.

Returns

The path to the checkpoint directory.

Return type

Path

get_streaming_checkpointer()[source]#

Returns the checkpoint manager, responsible for saving model checkpoints.

Returns

The checkpoint manager.

Return type

CheckpointManager

get_tensorboard()[source]#

Returns the TensorBoard SummaryWriter, used for logging metrics.

Returns

The TensorBoard SummaryWriter.

Return type

flax.metrics.tensorboard.SummaryWriter

get_wandb_init()[source]#

Initializes Weights & Biases for experiment tracking if enabled.

Returns

The WandB run object if initialized, else None.

Return type

tp.Optional[wandb.sdk.wandb_run.Run]

gradient_accumulation_steps: int = 1#
ids_to_pop_from_dataset: Optional[List[str]]#
init_tx: bool = True#
is_fine_tuning: bool = True#
property is_process_zero#
jax_distributed_config: Optional[dict] = None#
learning_rate: float = 5e-05#
learning_rate_end: Optional[float] = None#
log_all_workers: bool = False#
log_grad_norms: bool = True#
log_metrics(metrics: Any, step: int, log_as: Optional[Literal['summary', 'config']] = None)[source]#

Logs training metrics to Weights & Biases and/or TensorBoard.

Parameters
  • metrics (tp.Dict[str, tp.Union[float, tp.List, tp.Tuple, np.ndarray, 'jnp.ndarray', 'torch.Tensor']]) – A dictionary where keys are metric names and values are metric values.

  • step (int) – The current training step or iteration.

log_steps: int = 10#
log_weight_distribution(state, step: int)[source]#
loss_config: Optional[LossConfig] = None#
low_mem_usage: bool = True#
max_evaluation_steps: Optional[int] = None#
max_sequence_length: Optional[int] = 4096#
max_training_steps: Optional[int] = None#
metrics_to_show_in_rich_pbar: Optional[List[str]] = None#
model_name: str = 'BaseTrainer'#
model_parameters: Optional[dict] = None#
num_train_epochs: int = 10#
offload_dataset: bool = False#
property offload_device#
offload_device_index: int = 0#
offload_device_type: str = 'cpu'#
optimizer: Literal['adafactor', 'lion', 'adamw', 'rmsprop'] = 'adamw'#
performance_mode: bool = False#
process_zero_is_admin: bool = True#
progress_bar_type: Literal['tqdm', 'rich', 'json'] = 'tqdm'#
pruning_module: Any = None#
remove_ckpt_after_load: bool = False#
remove_unused_columns: bool = True#
replace(**kwargs)#
report_metrics: bool = True#
report_steps: int = 5#
save_directory: str = 'EasyDeL-Checkpoints'#
save_optimizer_state: bool = True#
save_steps: Optional[int] = None#
save_total_limit: Optional[int] = None#
scheduler: Literal['linear', 'cosine', 'none'] = 'None'#
shuffle_train_dataset: bool = True#
sparse_module_type: Literal['bcoo', 'bcsr', 'coo', 'csr'] = 'bcoo'#
sparsify_module: bool = False#
state_apply_fn_kwarguments_to_model: Optional[dict] = None#
step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp')#
step_start_point: Optional[int] = None#
to_dict()#

Convert the instance to a dictionary for JSON serialization.

to_json(**kwargs)#

Convert the instance to a JSON string.

total_batch_size: int = 32#
track_memory: bool = False#
train_on_inputs: bool = True#
training_time_limit: Optional[str] = None#
property training_time_seconds: int#
truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end'#
tx_mu_dtype: Optional[dtype] = None#
use_data_collactor: bool = True#
use_wandb: bool = True#
verbose: bool = True#
wandb_entity: Optional[str] = None#
warmup_steps: int = 0#
weight_decay: float = 0.01#
weight_distribution_log_steps: int = 0#
weight_distribution_pattern: str = '.*?(layernorm|norm).*?'#
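The default weight_distribution_pattern is an ordinary regular expression. The sketch below shows which names it would select, assuming it is matched against slash-separated parameter paths. Both the path format and the use of `re.match` are assumptions, and the parameter names are made up.

```python
import re

# Default value of weight_distribution_pattern from the listing above.
pattern = re.compile(r".*?(layernorm|norm).*?")

# Hypothetical parameter paths.
names = [
    "model/layers/0/input_layernorm/kernel",
    "model/norm/kernel",
    "model/layers/0/self_attn/q_proj/kernel",
]

# Keep only the names the pattern matches (anything containing "norm").
selected = [name for name in names if pattern.match(name)]
print(selected)
```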
class easydel.__init__.WhisperConfig(vocab_size=51865, num_mel_bins=80, encoder_layers=4, encoder_attention_heads=6, decoder_layers=4, decoder_attention_heads=6, decoder_ffn_dim=1536, encoder_ffn_dim=1536, encoder_layerdrop=0.0, decoder_layerdrop=0.0, decoder_start_token_id=50257, use_cache=True, is_encoder_decoder=True, activation_function='gelu', d_model=384, dropout=0.0, attention_dropout=0.0, activation_dropout=0.0, init_std=0.02, scale_embedding=False, max_source_positions=1500, max_target_positions=448, pad_token_id=50256, bos_token_id=50256, eos_token_id=50256, suppress_tokens=None, begin_suppress_tokens=[220, 50256], use_weighted_layer_sum=False, classifier_proj_size=256, apply_spec_augment=False, mask_time_prob=0.05, mask_time_length=10, mask_time_min_masks=2, mask_feature_prob=0.0, mask_feature_length=10, mask_feature_min_masks=0, median_filter_width=7, bits: Optional[int] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 51865) – Vocabulary size of the Whisper model. Defines the number of different tokens that can be represented by the inputs_ids passed when calling [~easydel.modules.WhisperModel].

  • num_mel_bins (int, optional, defaults to 80) – Number of mel bins used by the feature extractor.

  • encoder_layers (int, optional, defaults to 4) – Number of encoder layers.

  • encoder_attention_heads (int, optional, defaults to 6) – Number of attention heads for each attention layer in the Transformer encoder.

  • decoder_layers (int, optional, defaults to 4) – Number of decoder layers.

  • decoder_attention_heads (int, optional, defaults to 6) – Number of attention heads for each attention layer in the Transformer decoder.

  • decoder_ffn_dim (int, optional, defaults to 1536) – Dimensionality of the decoder feed-forward network (FFN) layer.

  • encoder_ffn_dim (int, optional, defaults to 1536) – Dimensionality of the encoder feed-forward network (FFN) layer.

  • encoder_layerdrop (float, optional, defaults to 0.0) – The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) for more details.

  • decoder_layerdrop (float, optional, defaults to 0.0) – The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) for more details.

  • d_model (int, optional, defaults to 384) – Dimensionality of the layers and the pooler layer.

  • activation_function (str, optional, defaults to “gelu”) – The non-linear activation function (function or string) in the encoder and pooler. If string, “gelu”, “relu”, “silu” and “gelu_new” are supported.

  • dropout (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • activation_dropout (float, optional, defaults to 0.0) – The dropout ratio for activations inside the fully connected layer.

  • init_std (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • scale_embedding (bool, optional, defaults to False) – Scale embeddings by dividing by sqrt(d_model).

  • max_source_positions (int, optional, defaults to 1500) – The maximum sequence length allowed for the source text input to the model. Any longer inputs will be truncated.

  • max_target_positions (int, optional, defaults to 448) – The maximum sequence length allowed for the target text input to the model. Any longer inputs will be truncated.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models).

  • apply_spec_augment (bool, optional, defaults to False) – Whether to apply SpecAugment data augmentation.

  • mask_time_prob (float, optional, defaults to 0.05) – Probability of each feature vector along the time axis to be chosen as the start of the vector span to be masked. Approximately mask_time_prob * sequence_length // mask_time_length feature vectors will be masked along the time axis. This is only relevant if apply_spec_augment is set to True.

  • mask_time_length (int, optional, defaults to 10) – Length of vector span along the time axis.

  • mask_time_min_masks (int, optional, defaults to 2) – The minimum number of masks of length mask_time_length generated along the time axis. Each mask is filled with floats sampled in (random_lower_bound, random_upper_bound).

  • mask_feature_prob (float, optional, defaults to 0.0) – Probability of each feature vector along the feature axis to be chosen as the start of the vector span to be masked. Approximately mask_feature_prob * hidden_size // mask_feature_length feature vectors will be masked along the feature axis. This is only relevant if apply_spec_augment is set to True.

  • mask_feature_length (int, optional, defaults to 10) – Length of vector span along the feature axis.

  • mask_feature_min_masks (int, optional, defaults to 0) – The minimum number of masks of length mask_feature_length generated along the feature axis. Each mask is filled with floats sampled in (random_lower_bound, random_upper_bound).

  • median_filter_width (int, optional, defaults to 7) – The width of the median filter applied to the mask.

  • bits (int, optional) – The number of bits to quantize the model to. If None, the model is not quantized.

  • gradient_checkpointing (str, optional, defaults to “nothing_saveable”) – What to save during gradient checkpointing. Choose one of “nothing_saveable”, “first_half_saveable”, “full_saveable”.
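The approximation quoted for mask_time_prob can be worked through numerically. The sketch below evaluates it with the default values, taking max_source_positions (1500) as the sequence length and applying mask_time_min_masks as a lower bound. Both choices are assumptions, and `approx_mask_count` is a hypothetical helper rather than part of EasyDeL.

```python
def approx_mask_count(
    prob: float, axis_length: int, span_length: int, min_masks: int = 0
) -> int:
    """Evaluate prob * axis_length // span_length, floored at min_masks."""
    estimate = int(prob * axis_length // span_length)
    # Assumed: the *_min_masks setting acts as a lower bound on the count.
    return max(estimate, min_masks)


# Time axis with the defaults: mask_time_prob=0.05, mask_time_length=10.
print(approx_mask_count(0.05, 1500, 10, min_masks=2))

# Feature axis with the defaults: mask_feature_prob=0.0 disables masking.
print(approx_mask_count(0.0, 384, 10, min_masks=0))
```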

attribute_map: dict[str, str] = {'hidden_size': 'd_model', 'num_attention_heads': 'encoder_attention_heads'}#
get_partition_rules(*args, **kwargs)[source]#

Returns the partition rules for the Whisper model. Arguments are ignored.

Returns

Partition rules.

Return type

tuple

model_type: str = 'whisper'#
class easydel.__init__.WhisperForAudioClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.WhisperForConditionalGeneration(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

compute_loss(*, labels: Optional[Union[Array, ndarray, bool, number]] = None, loss_config: Optional[LossConfig] = None, loss_kwargs: Optional[Dict] = None, **batch) Tuple[Any, LossMetrics][source]#

Computes the loss for the model given a batch of inputs and labels.

This method performs a forward pass using the provided batch arguments, then calculates the loss using the determined loss_function. It handles potential label inference (e.g., using input_ids as labels for Causal LM) and default loss configurations.

Parameters
  • labels (tp.Optional[chex.Array], optional) – The target labels. If None and the task is Causal LM, input_ids from the batch might be used. Defaults to None.

  • loss_config (tp.Optional[LossConfig], optional) – Specific configuration for the loss calculation. If None, defaults might be inferred (e.g., for sequence classification). Defaults to None.

  • loss_kwargs (tp.Optional[tp.Dict], optional) – Additional keyword arguments to pass directly to the loss function. Defaults to None.

  • **batch – Keyword arguments representing the input batch (e.g., input_ids, attention_mask).

Returns

A tuple containing:
  • The model’s output (a pytree, typically including logits, hidden states, etc.)

  • A LossMetrics object containing the calculated loss and potentially other metrics.

Return type

tp.Tuple[tp.Any, LossMetrics]

Raises
  • AssertionError – If labels are required for the loss function but are not provided or inferred.

  • AssertionError – If sequence classification loss is used without num_labels in the config.

decode(decoder_input_ids, encoder_outputs, encoder_attention_mask: Optional[Array] = None, decoder_attention_mask: Optional[Array] = None, decoder_position_ids: Optional[Array] = None, past_key_values: Optional[TransformerCache] = None, cache_metadata: Optional[TransformerMetadata] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None)[source]#
encode(input_features: Array, attention_mask: Optional[Array] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs)[source]#
generate(input_features, generation_config=None, logits_processor=None, return_timestamps=None, task=None, language=None, is_multilingual=None, **kwargs)[source]#

Generates sequences of token ids for models with a language modeling head.

Parameters
  • input_ids (chex.Array of shape (batch_size, sequence_length)) – The sequence used as a prompt for the generation.

  • generation_config (~generation.GenerationConfig, optional) – The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them. If generation_config is not provided, the default will be used, which has the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration. Please note that unspecified parameters will inherit [~generation.GenerationConfig]’s default values, whose documentation should be checked to parameterize generation.

  • trace (bool, optional, defaults to True) – Whether to trace generation. Setting trace=False should only be used for debugging and will lead to a considerably slower runtime.

  • logits_processor (LogitsProcessorList, optional) – Custom logits processors that complement the default logits processors built from arguments and generation config. If a logit processor is passed that is already created with the arguments or a generation config an error is thrown. This feature is intended for advanced users.

  • kwargs (tp.Dict[str, Any], optional) – Ad hoc parametrization of generate_config and/or additional model-specific kwargs that will be forwarded to the forward function of the model. If the model is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with decoder_.

Returns

[~utils.ModelOutput].

loss_type = 'ForCausalLM'#
prepare_inputs_for_generation(decoder_input_ids, max_length, attention_mask: Optional[Array] = None, decoder_attention_mask: Optional[Array] = None, encoder_outputs=None, **kwargs)[source]#

The prepare_inputs_for_generation function is used to prepare the inputs for a generation task.

Parameters
  • self – The model instance.

  • decoder_input_ids – The decoder input tokens.

  • max_length – The length of the sequence to be generated.

  • attention_mask (tp.Optional[chex.Array]) – Mask applied to the attention weights.

  • token_type_ids (tp.Optional[chex.Array]) – Token type IDs.

Returns

A dictionary of the past_key_values, attention_mask and position ids

update_inputs_for_generation(model_outputs, model_kwargs)[source]#
class easydel.__init__.WhisperTimeStampLogitsProcessor(generate_config, model_config, decoder_input_length)[source]#

Bases: LogitsProcessor

A specialized [LogitsProcessor] tailored for handling timestamp tokens during generation with Whisper-style models used for Automatic Speech Recognition (ASR).

It enforces several constraints specific to timestamp prediction:

  1. Suppresses `<|notimestamps|>`: Prevents the model from predicting the token that indicates the absence of timestamps.

  2. Alternating Tokens: Enforces that text tokens and timestamp tokens generally alternate. If the last generated token was a timestamp, it biases against predicting another timestamp immediately after (unless it’s the very beginning or certain edge cases).

  3. Initial Timestamp Limit: Restricts the maximum value of the first timestamp token predicted using max_initial_timestamp_index.

  4. Timestamp Probability Check: If the total probability mass assigned to all valid timestamp tokens is higher than the probability of the single most likely non-timestamp token, it forces the model to sample a timestamp token by suppressing all non-timestamp tokens.
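The Timestamp Probability Check constraint can be sketched in plain Python. This is an illustration rather than the EasyDeL implementation: `force_timestamp` and `logsumexp` are hypothetical helpers, the input is assumed to be a list of finite, normalized log-probabilities, and every token ID at or above timestamp_begin is treated as a timestamp token.

```python
import math


def logsumexp(values: list[float]) -> float:
    """Numerically stable log(sum(exp(v))) for finite inputs."""
    peak = max(values)
    return peak + math.log(sum(math.exp(v - peak) for v in values))


def force_timestamp(logprobs: list[float], timestamp_begin: int) -> list[float]:
    """Suppress text tokens when timestamps jointly outweigh the best text token."""
    timestamp_mass = logsumexp(logprobs[timestamp_begin:])  # total timestamp mass
    best_text = max(logprobs[:timestamp_begin])             # single best text token
    if timestamp_mass > best_text:
        # Setting text tokens to -inf forces sampling from the timestamp tokens.
        return [-math.inf] * timestamp_begin + logprobs[timestamp_begin:]
    return logprobs


# Two text tokens (IDs 0-1) followed by two timestamp tokens (IDs 2-3).
probs = [0.3, 0.2, 0.25, 0.25]
print(force_timestamp([math.log(p) for p in probs], timestamp_begin=2))
```

Comparing the summed timestamp mass against the single best text token, rather than against the summed text mass, deliberately tilts the decision toward emitting a timestamp.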

Note

This processor assumes the existence of specific token IDs related to timestamps (e.g., eos_token_id, no_timestamps_token_id, timestamp_begin) which are typically defined in the model’s generation configuration.

Parameters
  • generate_config – Configuration object containing Whisper-specific generation parameters like eos_token_id, no_timestamps_token_id, is_multilingual, max_initial_timestamp_index.

  • model_config – The model’s configuration (used for vocab_size as a fallback).

  • decoder_input_length – The length of the initial input sequence provided to the decoder (e.g., the prompt length).

class easydel.__init__.Xerxes2Config(vocab_size: int = 256128, hidden_size: int = 4096, intermediate_size: int = 16384, num_hidden_layers: int = 32, num_attention_heads: int = 32, max_position_embeddings: int = 16384, initializer_range: float = 0.02, rms_norm_eps: float = 1e-06, use_cache: bool = True, pad_token_id: int = 0, eos_token_id: int = 1, bos_token_id: int = 2, tie_word_embeddings: bool = False, rope_theta: float = 10000.0, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, scan_layers: bool = False, q_lora_dim: Optional[int] = 1536, kv_lora_dim: int = 512, qk_rope_head_dim: int = 64, qk_nope_head_dim: int = 128, vhead_dim: int = 128, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 256128) – Vocabulary size of the Xerxes2 model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 16384) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 16) – Number of key and value heads for each attention layer in the Transformer encoder.

  • head_dim (int, optional, defaults to 256) – Dimensionality of the attention head.

  • max_position_embeddings (int, optional, defaults to 16384) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional, defaults to 0) – The index of the padding token in the vocabulary.

  • eos_token_id (int, optional, defaults to 1) – The index of the end of sequence token in the vocabulary.

  • bos_token_id (int, optional, defaults to 2) – The index of the beginning of sequence token in the vocabulary.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • softmax_scale (float, optional, defaults to 14.9666295471) – softmax scale for attention module.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • bits (int, optional) – The number of bits to quantize the model to.

  • scan_layers (bool, optional, defaults to False) – Whether to use the scan implementation of the layers.

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

The attach_custom_arguments function adds the following arguments to the Transformer class:

Parameters
  • self – Refer to the current object

  • gradient_checkpointing – str: Control the amount of memory used by jax

  • bits – tp.Optional[int]: Determine the number of bits used in the quantization

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Parameters

fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()[source]#
property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
model_type: str = 'xerxes2'#
static rng_keys()[source]#
class easydel.__init__.Xerxes2ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

init_cache(batch_size: int, max_length: int)[source]#
class easydel.__init__.Xerxes2Model(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

property frequencies: Array#

Returns frequency values from the config.

class easydel.__init__.XerxesConfig(vocab_size=256128, hidden_size=4096, intermediate_size=16384, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=8, head_dim=144, max_position_embeddings=16384, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, pad_token_id=0, eos_token_id=1, bos_token_id=2, num_local_experts: int = 4, xe_moe: bool = True, num_experts_per_tok: int = 2, tie_word_embeddings=False, rope_theta=10000.0, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, scan_layers: bool = False, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 256128) – Vocabulary size of the Xerxes model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 16384) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 8) – Number of key and value heads for each attention layer in the Transformer encoder.

  • head_dim (int, optional, defaults to 144) – Dimensionality of the attention head.

  • max_position_embeddings (int, optional, defaults to 16384) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional, defaults to 0) – The index of the padding token in the vocabulary.

  • eos_token_id (int, optional, defaults to 1) – The index of the end of sequence token in the vocabulary.

  • bos_token_id (int, optional, defaults to 2) – The index of the beginning of sequence token in the vocabulary.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • softmax_scale (float, optional, defaults to 14.9666295471) – softmax scale for attention module.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • bits (int, optional) – The number of bits to quantize the model to.

  • scan_layers (bool, optional, defaults to False) – Whether to use the scan implementation of the layers.

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

The attach_custom_arguments function adds the following arguments to the Transformer class:

Parameters
  • self – Refer to the current object

  • gradient_checkpointing – str: Control the amount of memory used by jax

  • bits – tp.Optional[int]: Determine the number of bits used in the quantization

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Parameters

fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()[source]#
property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
model_type: str = 'xerxes'#
static rng_keys()[source]#
class easydel.__init__.XerxesForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.XerxesModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

easydel.__init__.auto_pytree(cls=None, meta_fields: Optional[Tuple[str, ...]] = None, json_serializable: bool = True, frozen: bool = False)[source]#

Register a class as a JAX PyTree with performance optimizations.

easydel.__init__.get_modules_by_type(model_type: str, task_type: TaskType) → Tuple[Type[EasyDeLBaseConfig], Union[Type[EasyDeLBaseModule], Any]][source]#

The get_modules_by_type function is a helper function that returns the following:
  1. The config class for the model type specified (e.g., LlamaConfig, FalconConfig)

  2. The EasyDeL Model class for the model type specified (e.g., FlaxLlamaForCausalLM, FalconForCausalLM)

easydel.__init__.module_to_huggingface_model(module: Any, config: Any, base_huggingface_module: Any, base_huggingface_module_kwarguments: Optional[Dict] = None, dtype: dtype = <class 'jax.numpy.float16'>, use_meta_torch: bool = True, **kw)[source]#
easydel.__init__.module_to_torch(module: Any, dtype: dtype = <class 'jax.numpy.float16'>)[source]#
easydel.__init__.pack_sequences(dataset: Any, max_length: int = 512, pad_token_id: int = 0, reset_position_ids: bool = False, num_proc: Optional[int] = None)[source]#

Pack sequences together with their attention masks and position IDs.

    # With continuous position IDs
    packed_dataset = pack_sequences(
        dataset, max_length=512, pad_token_id=0, reset_position_ids=False
    )

    # With reset position IDs for each sequence
    packed_dataset = pack_sequences(
        dataset, max_length=512, pad_token_id=0, reset_position_ids=True
    )

    # Example output format for a packed sequence with two sequences:
    # reset_position_ids=False:
    {
        'input_ids': [seq1_tokens + [PAD] + seq2_tokens + [PAD] + padding],
        'attention_mask': [1,1,1,0,1,1,1,0,0,0],
        'position_ids': [0,1,2,3,4,5,6,7,0,0]
    }

    # reset_position_ids=True:
    {
        'input_ids': [seq1_tokens + [PAD] + seq2_tokens + [PAD] + padding],
        'attention_mask': [1,1,1,0,1,1,1,0,0,0],
        'position_ids': [0,1,2,0,0,1,2,0,0,0]
    }

Parameters
  • dataset – Dataset containing ‘input_ids’ and ‘attention_mask’

  • max_length – Maximum length of packed sequence

  • pad_token_id – Token ID used for padding

  • reset_position_ids – If True, reset position IDs for each sequence in the pack

Returns

packed_dataset: Dataset with packed sequences, attention masks, and position IDs
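The output format documented for pack_sequences can be reproduced with a small pure-Python sketch. It is inferred only from the example output in the docstring and is not the library's implementation: the helper name pack_example is hypothetical, it builds a single pack, and it simply stops when the next sequence no longer fits.

```python
from typing import Dict, List


def pack_example(
    sequences: List[List[int]],
    max_length: int,
    pad_token_id: int = 0,
    reset_position_ids: bool = False,
) -> Dict[str, List[int]]:
    """Build one packed row: each sequence is followed by a PAD separator."""
    input_ids: List[int] = []
    attention_mask: List[int] = []
    position_ids: List[int] = []
    for seq in sequences:
        # +1 reserves room for the separator that follows the sequence.
        if len(input_ids) + len(seq) + 1 > max_length:
            break  # assumption: a full packer would start a new row here
        start = 0 if reset_position_ids else len(input_ids)
        input_ids.extend(seq)
        attention_mask.extend([1] * len(seq))
        position_ids.extend(range(start, start + len(seq)))
        # Separator: masked out; its position continues or resets to 0.
        input_ids.append(pad_token_id)
        attention_mask.append(0)
        position_ids.append(0 if reset_position_ids else len(position_ids))
    # Trailing padding up to max_length.
    remaining = max_length - len(input_ids)
    input_ids.extend([pad_token_id] * remaining)
    attention_mask.extend([0] * remaining)
    position_ids.extend([0] * remaining)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
    }


packed = pack_example([[11, 12, 13], [21, 22, 23]], max_length=10)
# packed["position_ids"] == [0, 1, 2, 3, 4, 5, 6, 7, 0, 0]
```

With reset_position_ids=True the same call yields position IDs [0, 1, 2, 0, 0, 1, 2, 0, 0, 0], matching the second example in the docstring.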

easydel.__init__.register_config(config_type: str, config_field: ConfigType = ConfigType.MODULE_CONFIG) → callable#

Decorator factory to register a configuration class.

Parameters
  • config_type (str) – A unique string identifier for this configuration class (e.g., “llama”).

  • config_field (ConfigType) – The category under which to register the config. Defaults to ConfigType.MODULE_CONFIG.

Returns

A decorator that takes the configuration class, registers it, and enhances its string representation.

Return type

callable
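The decorator-factory pattern described for register_config can be sketched in plain Python. Everything below is hypothetical (the registry dictionary, the function name register_config_sketch, and ToyConfig); it only illustrates registering a class under a string key and attaching a more readable string representation, and makes no claim about how EasyDeL stores its registry.

```python
from typing import Callable, Dict, Type

# Hypothetical registry: maps a config_type string to a config class.
CONFIG_REGISTRY: Dict[str, Type] = {}


def register_config_sketch(config_type: str) -> Callable[[Type], Type]:
    """Return a decorator that registers a class under config_type."""

    def decorator(cls: Type) -> Type:
        CONFIG_REGISTRY[config_type] = cls

        def __str__(self) -> str:
            # Render instance attributes in a stable, sorted order.
            fields = ", ".join(
                f"{key}={value!r}" for key, value in sorted(vars(self).items())
            )
            return f"{cls.__name__}({fields})"

        cls.__str__ = __str__
        return cls

    return decorator


@register_config_sketch("toy")
class ToyConfig:
    def __init__(self, hidden_size: int = 8):
        self.hidden_size = hidden_size
```

After decoration, CONFIG_REGISTRY["toy"] refers to ToyConfig, and str(ToyConfig()) renders the instance's attributes.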

easydel.__init__.register_module(task_type: TaskType, config: EasyDeLBaseConfig, model_type: str, embedding_layer_names: Optional[List[str]] = None, layernorm_names: Optional[List[str]] = None) → callable#

Decorator factory to register an EasyDeL module class for a specific task.

Parameters
  • task_type (TaskType) – The task the module is designed for (e.g., TaskType.CAUSAL_LM).

  • config (EasyDeLBaseConfig) – The configuration class associated with this module.

  • model_type (str) – A unique string identifier for this model implementation (e.g., “llama”).

  • embedding_layer_names (tp.Optional[tp.List[str]]) – Optional list of embedding layer names. Defaults to None.

  • layernorm_names (tp.Optional[tp.List[str]]) – Optional list of LayerNorm layer names. Defaults to None.

Returns

A decorator that takes the module class, registers it with its metadata, and sets internal _model_task and _model_type attributes on the class.

Return type

callable
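As a rough illustration of the behaviour described for register_module, the sketch below registers a module class under a (task, model type) key and sets the _model_task and _model_type attributes on it. The enum, the registry, and the lookup helper are all hypothetical stand-ins; the lookup only mirrors the kind of (config class, module class) pair that get_modules_by_type is documented to return.

```python
from enum import Enum
from typing import Callable, Dict, Tuple, Type


class ToyTaskType(Enum):
    """Hypothetical stand-in for TaskType."""

    CAUSAL_LM = "causal-language-model"


# Hypothetical registry keyed by (task, model_type).
MODULE_REGISTRY: Dict[Tuple[ToyTaskType, str], Tuple[Type, Type]] = {}


def register_module_sketch(
    task_type: ToyTaskType, config: Type, model_type: str
) -> Callable[[Type], Type]:
    """Return a decorator that registers a module class with its metadata."""

    def decorator(cls: Type) -> Type:
        cls._model_task = task_type
        cls._model_type = model_type
        MODULE_REGISTRY[(task_type, model_type)] = (config, cls)
        return cls

    return decorator


def lookup_modules(model_type: str, task_type: ToyTaskType) -> Tuple[Type, Type]:
    """Return the (config class, module class) pair for a registered model."""
    return MODULE_REGISTRY[(task_type, model_type)]


class ToyModelConfig:
    pass


@register_module_sketch(ToyTaskType.CAUSAL_LM, ToyModelConfig, "toy")
class ToyForCausalLM:
    pass
```

Keying the registry on both the task and the model type lets several task heads share one model_type string without colliding.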

easydel.__init__.torch_dict_to_easydel_params(state_dict: Dict[str, Any], *, device: Optional[Device] = None, embedding_layer_names: Optional[List[str]] = None, layernorm_names: Optional[List[str]] = None, shard_fns: Optional[Mapping[tuple, Callable]] = None, dtype: dtype = <class 'jax.numpy.float16'>, verbose: bool = True, callback: Optional[Callable[[Array, tuple], Array]] = None, remove_state_dict: bool = False, lm_head_name: Optional[str] = None, uses_tie_word_embedding: bool = False, **kwargs) → Dict[str, Any][source]#

Convert PyTorch state dict to EasyDeL parameter format.

Parameters
  • state_dict – PyTorch state dictionary

  • device – JAX device to use

  • embedding_layer_names – Names of embedding layers

  • layernorm_names – Names of layer normalization layers

  • shard_fns – Mapping of parameter names to sharding functions

  • block_size – Size of processing blocks

  • params_pattern_selection – Regex pattern for parameter selection

  • dtype – Target dtype for parameters

  • verbose – Whether to show progress bar

  • callback – Callback applied to tensors after they are converted to a JAX array

  • remove_state_dict – Whether to delete state_dict after conversion

  • lm_head_name – Name of language model head

  • uses_tie_word_embedding – Whether model uses tied embeddings

  • **kwargs – Additional arguments

Returns

Dictionary of converted parameters in EasyDeL format

class easydel.__init__.vInference(model: None, processor_class: None, generation_config: Optional[vInferenceConfig] = None, seed: Optional[int] = None, input_partition_spec: Optional[PartitionSpec] = None, max_new_tokens: int = 512, inference_name: Optional[str] = None)[source]#

Bases: object

Class for performing text generation using a pre-trained language model in EasyDeL.

This class handles the generation process, including initialization, precompilation, and generating text in streaming chunks.

property SEQUENCE_DIM_MAPPING#
adjust_kwargs(input_ids: Array, attention_mask: Optional[Array] = None, **model_kwargs)[source]#
count_tokens(messages: List[Dict[str, str]])[source]#
count_tokens(text: str)
execute_decode(state: SampleState, *, graphstate: Optional[State[Key, VariableState[Any]]] = None, graphother: Optional[State[Key, VariableState[Any]]] = None, compile_config: Optional[vInferencePreCompileConfig] = None, sampling_params: Optional[SamplingParams] = None, func: Optional[Callable[[Any], SampleState]]) → SampleState[source]#
execute_prefill(state: SampleState, *, graphstate: Optional[State[Key, VariableState[Any]]] = None, graphother: Optional[State[Key, VariableState[Any]]] = None, compile_config: Optional[vInferencePreCompileConfig] = None, sampling_params: Optional[SamplingParams] = None, func: Optional[Callable[[Any], SampleState]]) → SampleState[source]#

Executes a single generation step with performance monitoring.

generate(input_ids: Array, attention_mask: Optional[Array] = None, *, graphstate: Optional[State[Key, VariableState[Any]]] = None, graphother: Optional[State[Key, VariableState[Any]]] = None, sampling_params: Optional[SamplingParams] = None, **model_kwargs) → Generator[Union[SampleState, Any], SampleState, SampleState][source]#

Generates text in streaming chunks with comprehensive input adjustment.

Parameters
  • input_ids – Input token IDs as a JAX array

  • attention_mask – Optional attention mask for the input

  • graphstate (nn.GraphState, optional) – Graph state to use for generation, for cases where the model state should be updated.

  • graphother (nn.GraphState, optional) – Other graph state to use for generation, for cases where the model's other state should be updated.

  • **model_kwargs – Additional model-specific keyword arguments

Returns

Generator yielding SampleState objects containing generation results and metrics

property inference_name#
classmethod load_inference(path: Union[PathLike, str], model: None, processor_class: None)[source]#
property metrics#
property model#
property model_prefill_length: int#

Calculate the maximum length available for input prefill by subtracting the maximum new tokens from the model’s maximum sequence length.

Returns

The maximum length available for input prefill

Return type

int

Raises

ValueError – If no maximum sequence length configuration is found
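The arithmetic behind model_prefill_length is simple enough to show directly. The sketch below is a standalone illustration, not the property's code: the function name and the way the maximum sequence length is passed in are assumptions.

```python
from typing import Optional


def prefill_length(max_sequence_length: Optional[int], max_new_tokens: int) -> int:
    """Tokens left for the prompt after reserving room for generation."""
    if max_sequence_length is None:
        # Mirrors the documented ValueError when no limit is configured.
        raise ValueError("no maximum sequence length configuration found")
    return max_sequence_length - max_new_tokens


available = prefill_length(max_sequence_length=4096, max_new_tokens=512)
# available == 3584
```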

precompile(config: vInferencePreCompileConfig)[source]#

Precompiles the generation functions for a given batch size and input length.

This function checks if the generation functions have already been compiled for the given configuration. If not, it compiles them asynchronously and stores them in a cache.

Returns

True if precompilation was successful, False otherwise.

Return type

bool
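The check-then-compile behaviour described for precompile follows a common caching pattern, sketched below. This is a synchronous simplification with hypothetical names (CompileCacheSketch, compile_fn); the documented method compiles asynchronously, and returning True on a cache hit is an assumption.

```python
from typing import Any, Callable, Dict


class CompileCacheSketch:
    """Cache compiled artifacts by an integer configuration hash."""

    def __init__(self, compile_fn: Callable[[int], Any]):
        self._compile_fn = compile_fn
        self._cache: Dict[int, Any] = {}

    def precompile(self, config_hash: int) -> bool:
        if config_hash in self._cache:
            return True  # already compiled for this configuration
        try:
            self._cache[config_hash] = self._compile_fn(config_hash)
        except Exception:
            return False  # compilation failed, nothing is cached
        return True
```

Keying the cache on a configuration hash means a repeated call with an identical configuration skips the expensive compile step.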

save_inference(path: Union[PathLike, str])[source]#
property tokenizer#
class easydel.__init__.vInferenceApiServer(inference_map: Union[Dict[str, Any], Any] = None, inference_init_call: Optional[Callable[[], Any]] = None, max_workers: int = 10)[source]#

Bases: object

FastAPI server for serving vInference instances.

This server provides endpoints mimicking the OpenAI API structure for chat completions, liveness/readiness checks, token counting, and listing available models. It handles both streaming and non-streaming requests asynchronously using a thread pool.

async available_inference()[source]#

Lists available models (GET /v1/models).

async chat_completions(request: ChatCompletionRequest)[source]#

Handles chat completion requests (POST /v1/chat/completions).

Validates the request, retrieves the appropriate vInference model, tokenizes the input, and delegates to streaming or non-streaming handlers.

Parameters

request (ChatCompletionRequest) – The incoming request data.

Returns

The generated response, either a complete JSON object or a streaming event-stream.

Return type

Union[JSONResponse, StreamingResponse]
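A client request for this endpoint can be assembled with the standard library alone. The sketch below only builds the request object and does not send it. The payload field names (model, messages, stream) are assumptions based on the OpenAI-style structure the server is said to mimic, and the helper name and model name are hypothetical.

```python
import json
import urllib.request


def build_chat_request(
    prompt: str,
    model: str,
    host: str = "127.0.0.1",
    port: int = 11556,
    stream: bool = False,
) -> urllib.request.Request:
    """Build (but do not send) a POST request for /v1/chat/completions."""
    payload = {
        "model": model,  # assumed field name
        "messages": [{"role": "user", "content": prompt}],
        "stream": stream,  # assumed field name
    }
    return urllib.request.Request(
        url=f"http://{host}:{port}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


request = build_chat_request("Hello", model="my-model")
```

Against a running server, the request could be sent with urllib.request.urlopen(request) and the body decoded as JSON.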

async completions(request: CompletionRequest)[source]#

Handles completion requests (POST /v1/completions).

Processes the prompt for completion and returns generated text.

Parameters

request (CompletionRequest) – The incoming request data.

Returns

The generated response.

Return type

Union[JSONResponse, StreamingResponse]

async count_tokens(request: CountTokenRequest)[source]#

Token counting endpoint (POST /v1/count_tokens).

fire(host='0.0.0.0', port=11556, metrics_port: Optional[int] = None, log_level='info', ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None)[source]#

Starts the uvicorn server to run the FastAPI application.

Parameters
  • host (str) – The host address to bind to. Defaults to “0.0.0.0”.

  • port (int) – The port to listen on. Defaults to 11556.

  • metrics_port (tp.Optional[int]) – The port for the Prometheus metrics server. If None, defaults to port + 1. Set to -1 to disable.

  • log_level (str) – The logging level for uvicorn. Defaults to “info”.

  • ssl_keyfile (tp.Optional[str]) – Path to the SSL key file for HTTPS.

  • ssl_certfile (tp.Optional[str]) – Path to the SSL certificate file for HTTPS.
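The metrics_port rules stated above (None means port + 1, and -1 disables the metrics server) can be written out as a small helper. The function name is hypothetical, and returning None to signal "disabled" is a choice made for this sketch only.

```python
from typing import Optional


def resolve_metrics_port(
    port: int = 11556, metrics_port: Optional[int] = None
) -> Optional[int]:
    """Return the metrics port to use, or None when metrics are disabled."""
    if metrics_port is None:
        return port + 1  # documented default
    if metrics_port == -1:
        return None  # documented way to disable metrics
    return metrics_port
```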

async liveness()[source]#

Liveness check endpoint (GET /liveness).

async readiness()[source]#

Readiness check endpoint (GET /readiness).

class easydel.__init__.vInferenceConfig(max_new_tokens: int = 64, streaming_chunks: int = 16, num_return_sequences: Optional[Union[int, Dict[int, int]]] = 1, pad_token_id: Optional[int] = None, bos_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, partition_rules: Optional[Tuple[Tuple[str, Any]]] = None, partition_axis: Optional[PartitionAxis] = None, _loop_rows: Optional[int] = None, sampling_params: Optional[SamplingParams] = None)[source]#

Bases: object

Configuration class for the vInference engine, controlling the overall generation process.

This class holds parameters that define how the generation loop behaves, including length constraints, token control, sharding strategies, and sampling settings.

max_new_tokens#

The maximum number of new tokens to generate, excluding the initial prompt tokens. Defaults to 64.

Type

int

streaming_chunks#

The number of generation steps to compile and execute together as a single unit. Larger chunks can improve performance on TPUs by reducing compilation overhead and kernel launch times, but may increase memory usage. Defaults to 16.

Type

int

num_return_sequences#

The number of sequences to generate and return. Can be:
  • An integer: Generate this many sequences for all inputs.

  • A dictionary mapping precompile hash (from vInferencePreCompileConfig) to an integer: Generate a specific number of sequences based on the compilation configuration.

Defaults to 1.

Type

Optional[Union[int, Dict[int, int]]]

pad_token_id#

The token ID used for padding sequences. If None, the model’s default pad token ID might be used, or padding might not be applied.

Type

Optional[int]

bos_token_id#

The token ID representing the beginning-of-sequence. May be used implicitly by the model or generation logic.

Type

Optional[int]

eos_token_id#

The token ID(s) representing the end-of-sequence. Generation stops for a sequence when one of these tokens is sampled. Can be a single integer or a list/tuple of integers.

Type

Optional[Union[int, List[int]]]

partition_rules#

A tuple of custom sharding rules (regex pattern, PartitionSpec) to apply to the model’s parameters and intermediate states (like attention cache). If None, default rules based on partition_axis are generated. Example: ((“.*kernel.*”, PartitionSpec(“fsdp”, None)), …)

Type

Optional[Tuple[Tuple[str, Any]]]

partition_axis#

A PartitionAxis object defining the logical names for sharding axes (e.g., ‘batch’, ‘sequence’, ‘head’). Required if partition_rules is None, used to generate default sharding rules.

Type

Optional[eformer.escale.partition.constraints.PartitionAxis]

_loop_rows#

(Internal) The calculated number of iterations needed in the generation loop based on max_new_tokens and streaming_chunks. Automatically computed in __post_init__.

Type

Optional[int]

sampling_params#

A SamplingParams object containing parameters for the sampling process itself (e.g., temperature, top_k, top_p, repetition penalty). If None, a default SamplingParams instance with max_tokens set to max_new_tokens is created in __post_init__.

Type

Optional[easydel.__init__.inference.utilities.SamplingParams]
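The relationship between max_new_tokens, streaming_chunks, and the internal _loop_rows attribute described above can be illustrated with a ceiling division. This is one plausible calculation, stated as an assumption; the attribute documentation does not give the exact formula.

```python
def loop_rows(max_new_tokens: int = 64, streaming_chunks: int = 16) -> int:
    """Number of chunked iterations needed to cover max_new_tokens.

    Assumption: the value is the ceiling of max_new_tokens / streaming_chunks.
    """
    return (max_new_tokens + streaming_chunks - 1) // streaming_chunks


rows = loop_rows()  # defaults: 64 new tokens in chunks of 16
# rows == 4
```

Under this assumption, a request for 65 new tokens with chunks of 16 would need 5 iterations, the last one only partially used.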

bos_token_id: Optional[int] = None#
eos_token_id: Optional[Union[int, List[int]]] = None#
classmethod from_dict(data)#

Create an instance from a dictionary (deserialization).

classmethod from_json(json_str)#

Create an instance from a JSON string.

get_partition_rules(runtime_config: Optional[vInferencePreCompileConfig] = None) → Tuple[Tuple[str, Any], ...][source]#

Generates or retrieves the sharding partition rules for the vInference engine.

If self.partition_rules is already set (custom rules provided), it returns them directly.

Otherwise, it constructs a default set of partition rules based on the axis names defined in self.partition_axis. These default rules aim to provide sensible sharding for common model components:
  • Input sequences (sequences, running_token) are sharded along batch and sequence axes.

  • Attention masks and position IDs are sharded similarly.

  • Past key-value states (attention cache), including common quantized formats (8-bit, NF4), are sharded across batch, key sequence, head, and attention dimension axes.

  • Any parameters/states not matching the specific rules are replicated by default (.*).

Parameters

runtime_config – An optional vInferencePreCompileConfig. Currently unused in the default rule generation but available for potential customization in subclasses or future versions.

Returns

A tuple of partition rules. Each rule is a tuple containing:
  • A regex pattern (string) matching parameter or state names.

  • A jax.sharding.PartitionSpec defining how the matched items should be sharded.

Raises

AssertionError – If self.partition_rules is None and self.partition_axis is also None, as axis names are required to generate default rules.
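How a tuple of (regex pattern, spec) rules with a trailing `.*` catch-all is typically consumed can be shown without JAX. In the sketch below, plain tuples of axis names stand in for PartitionSpec objects, the rule contents are illustrative rather than the engine's real defaults, and first-match-wins resolution is an assumption about how the rules are applied.

```python
import re
from typing import Any, Sequence, Tuple

Rule = Tuple[str, Any]


def match_rule(name: str, rules: Sequence[Rule]) -> Any:
    """Return the spec of the first rule whose pattern matches name."""
    for pattern, spec in rules:
        if re.search(pattern, name) is not None:
            return spec
    raise ValueError(f"no partition rule matches {name!r}")


# Illustrative rules; tuples of axis names stand in for PartitionSpec.
rules = (
    ("(sequences|running_token)", ("batch", "sequence")),
    ("attention_mask", ("batch", "sequence")),
    (".*", ()),  # catch-all: replicate everything else
)

spec = match_rule("running_token", rules)
```

Because the catch-all pattern comes last, specific rules take precedence and any unmatched name falls through to replication.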

max_new_tokens: int = 64#
num_return_sequences: Optional[Union[int, Dict[int, int]]] = 1#
pad_token_id: Optional[int] = None#
partition_axis: Optional[PartitionAxis] = None#
partition_rules: Optional[Tuple[Tuple[str, Any]]] = None#
replace(**kwargs)#
sampling_params: Optional[SamplingParams] = None#
streaming_chunks: int = 16#
to_dict()#

Convert the instance to a dictionary for JSON serialization.

to_json(**kwargs)#

Convert the instance to a JSON string.

class easydel.__init__.vInferencePreCompileConfig(batch_size: Union[int, List[int]] = 1, prefill_length: Optional[Union[int, List[int]]] = None, vision_included: Union[bool, List[bool]] = False, vision_batch_size: Optional[Union[int, List[int]]] = None, vision_channels: Optional[Union[int, List[int]]] = None, vision_height: Optional[Union[int, List[int]]] = None, vision_width: Optional[Union[int, List[int]]] = None, required_props: Optional[Union[Mapping[str, Dict[str, Any]], List[Mapping[str, Dict[str, Any]]]]] = None)[source]#

Bases: object

Configuration class for pre-compiling vInference functions.

This class holds parameters that define the shape and properties of inputs expected by the vInference engine during pre-compilation. It allows specifying different configurations, potentially in lists, to compile for multiple scenarios.

batch_size#

Batch size or list of batch sizes for text generation.

Type

Union[int, List[int]]

prefill_length#

Prefill sequence length or list of lengths. If None, it might be inferred or not used depending on the context.

Type

Optional[Union[int, List[int]]]

vision_included#

Whether vision inputs are included in the model.

Type

Union[bool, List[bool]]

vision_batch_size#

Batch size for vision inputs. Only relevant if vision_included is True.

Type

Optional[Union[int, List[int]]]

vision_channels#

Number of channels for vision inputs. Only relevant if vision_included is True.

Type

Optional[Union[int, List[int]]]

vision_height#

Height of vision inputs. Only relevant if vision_included is True.

Type

Optional[Union[int, List[int]]]

vision_width#

Width of vision inputs. Only relevant if vision_included is True.

Type

Optional[Union[int, List[int]]]

required_props#

Optional dictionary or list of dictionaries specifying required properties for advanced configuration (e.g., specific model arguments).

Type

Optional[Union[Mapping[str, Dict[str, Any]], List[Mapping[str, Dict[str, Any]]]]]

batch_size: Union[int, List[int]] = 1#
extract() → dict[source]#

Converts the configuration instance into a dictionary.

This method is useful for serialization or easily accessing all configuration values.

Returns

A dictionary representation of the vInferencePreCompileConfig instance.

classmethod from_dict(data)#

Create an instance from a dictionary (deserialization).

classmethod from_json(json_str)#

Create an instance from a JSON string.

get_default_hash() → int[source]#

Generates a unique integer hash representing the configuration.

This hash is calculated based on the string representation of all configuration attributes, ensuring that identical configurations produce the same hash. This is crucial for caching compiled functions based on their configuration.

Returns

An integer hash value representing the configuration.
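A hash derived from the string representation of every attribute can be sketched with a small dataclass. The class and method names are hypothetical, and the use of SHA-256 is an assumption chosen so the value is stable across processes; the documentation does not say which hash function get_default_hash uses.

```python
import hashlib
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class ToyCompileConfig:
    """Hypothetical two-field stand-in for a precompile configuration."""

    batch_size: int = 1
    prefill_length: Optional[int] = None

    def stable_hash(self) -> int:
        # Join "name=value" for every field, then hash the resulting string.
        text = "-".join(f"{f.name}={getattr(self, f.name)}" for f in fields(self))
        digest = hashlib.sha256(text.encode("utf-8")).hexdigest()
        return int(digest[:16], 16)
```

Two instances with identical field values produce the same integer, which is what makes the value usable as a cache key for compiled functions.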

get_standalones() → List[vInferencePreCompileConfig][source]#

Generates a list of standalone configurations from a potentially multi-value config.

If any attribute in the current configuration is a list (indicating multiple scenarios), this method expands the configuration into multiple individual vInferencePreCompileConfig instances. Each resulting instance represents a single, specific compilation scenario.

If an attribute’s list is shorter than the longest list among all attributes, its last element is repeated to ensure all generated configurations have values for all attributes.

If the original configuration is already standalone (no list attributes), this method returns a list containing only the original instance.

Returns

A list of vInferencePreCompileConfig instances, each representing a single, standalone compilation scenario.
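The expansion rule described above (one configuration per index, with shorter lists repeating their last element) can be illustrated on plain dictionaries. The helper name is hypothetical, and the sketch assumes every list-valued attribute is non-empty.

```python
from typing import Any, Dict, List


def expand_standalones(config: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Expand list-valued entries into one dictionary per scenario."""
    lengths = [len(v) for v in config.values() if isinstance(v, list)]
    if not lengths:
        return [config]  # already standalone
    expanded = []
    for index in range(max(lengths)):
        item = {}
        for key, value in config.items():
            if isinstance(value, list):
                # Shorter lists repeat their last element.
                item[key] = value[min(index, len(value) - 1)]
            else:
                item[key] = value
        expanded.append(item)
    return expanded


scenarios = expand_standalones(
    {"batch_size": [1, 2, 4], "prefill_length": [512, 1024], "vision_included": False}
)
```

The third scenario pairs batch_size 4 with prefill_length 1024, because the shorter prefill_length list repeats its last element.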

prefill_length: Optional[Union[int, List[int]]] = None#
replace(**kwargs)#
required_props: Optional[Union[Mapping[str, Dict[str, Any]], List[Mapping[str, Dict[str, Any]]]]] = None#
to_dict()#

Convert the instance to a dictionary for JSON serialization.

to_json(**kwargs)#

Convert the instance to a JSON string.

vision_batch_size: Optional[Union[int, List[int]]] = None#
vision_channels: Optional[Union[int, List[int]]] = None#
vision_height: Optional[Union[int, List[int]]] = None#
vision_included: Union[bool, List[bool]] = False#
vision_width: Optional[Union[int, List[int]]] = None#
class easydel.__init__.vWhisperInference(model: Any, tokenizer: Any, processor: Any, inference_config: Optional[vWhisperInferenceConfig] = None, dtype: Union[str, type[Any], dtype, SupportsDType] = <class 'jax.numpy.float32'>)[source]#

Bases: object

Whisper inference pipeline for performing speech-to-text transcription or translation.

Parameters
  • model (WhisperForConditionalGeneration) – The fine-tuned Whisper model to use for inference.

  • tokenizer (WhisperTokenizer) – Tokenizer for Whisper.

  • processor (WhisperProcessor) – Processor for Whisper.

  • inference_config (vWhisperInferenceConfig, optional) – Inference configuration.

  • dtype (jax.typing.DTypeLike, optional, defaults to jnp.float32) – Data type for computations.

generate(audio_input: Union[str, bytes, ndarray, Dict[str, Union[ndarray, int]]], chunk_length_s: float = 30.0, stride_length_s: Optional[Union[float, list[float]]] = None, batch_size: Optional[int] = None, language: Optional[str] = None, task: Optional[str] = None, return_timestamps: Optional[bool] = None)[source]#

Transcribe or translate audio input.

Parameters
  • audio_input (tp.Union[str, bytes, np.ndarray, tp.Dict[str, tp.Union[np.ndarray, int]]]) – Input audio. Can be a local file path, URL, bytes, numpy array, or a dictionary containing the array and sampling rate.

  • chunk_length_s (float, optional, defaults to 30.0) – Length of audio chunks in seconds.

  • stride_length_s (float or list[float], optional) – Stride length for chunking audio, in seconds. Defaults to chunk_length_s / 6.

  • batch_size (int, optional) – Batch size for processing. Defaults to the batch_size in inference_config.

  • language (str, optional) – Language of the input audio. Defaults to the language in inference_config.

  • task (str, optional) – Task to perform (e.g., “transcribe”, “translate”). Defaults to the task in inference_config.

  • return_timestamps (bool, optional) – Whether to return timestamps with the transcription. Defaults to the return_timestamps in inference_config.

Returns

A dictionary containing the transcribed text under the key “text” and, optionally, additional information such as timestamps or the detected language.

Return type

dict
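The interaction between chunk_length_s and stride_length_s can be sketched in plain Python. This is a minimal illustration, not the library's implementation: chunk_windows is a hypothetical helper, and it assumes the stride is applied equally on both sides of each chunk (the convention used by common chunked speech-recognition pipelines). Only a single float stride is handled; the list[float] form is left out.

```python
from typing import List, Optional, Tuple


def chunk_windows(
    num_samples: int,
    sampling_rate: int,
    chunk_length_s: float = 30.0,
    stride_length_s: Optional[float] = None,
) -> List[Tuple[int, int]]:
    """Return (start, end) sample indices of overlapping chunks (hypothetical helper)."""
    if stride_length_s is None:
        # Documented default: one sixth of the chunk length.
        stride_length_s = chunk_length_s / 6
    chunk_len = round(chunk_length_s * sampling_rate)
    stride = round(stride_length_s * sampling_rate)
    # Assumption: the stride overlaps on both the left and right side,
    # so consecutive chunks advance by chunk_len - 2 * stride samples.
    step = chunk_len - 2 * stride
    if step <= 0:
        raise ValueError("stride_length_s must be less than half of chunk_length_s")
    windows = []
    start = 0
    while True:
        end = min(start + chunk_len, num_samples)
        windows.append((start, end))
        if end >= num_samples:
            break
        start += step
    return windows


# 70 seconds of 16 kHz audio with the documented defaults (30 s chunks, 5 s stride).
windows = chunk_windows(num_samples=70 * 16_000, sampling_rate=16_000)
```

Under these assumptions the 70-second input is covered by three windows that each overlap their neighbour by 10 seconds, which is why a larger stride increases the number of chunks processed.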

class easydel.__init__.vWhisperInferenceConfig(batch_size: Optional[int] = 1, max_length: Optional[int] = None, generation_config: Optional[Any] = None)

Bases: object

Configuration class for Whisper inference.

Parameters
  • batch_size (int, optional, defaults to 1) – Batch size used for inference.

  • max_length (int, optional) – Maximum sequence length for generation.

  • generation_config (transformers.GenerationConfig, optional) – Generation configuration object.

  • logits_processor (optional) – Not used.

  • return_timestamps (bool, optional) – Whether to return timestamps with the transcribed text.

  • task (str, optional) – Task for the model (e.g., “transcribe”, “translate”).

  • language (str, optional) – Language of the input audio.

  • is_multilingual (bool, optional) – Whether the model is multilingual.
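
According to the generate() documentation, batch_size, language, task, and return_timestamps fall back to the values stored on the inference configuration when they are not passed explicitly. The sketch below illustrates that fallback rule only. StandInConfig and resolve_call_options are hypothetical names introduced for the example and are not part of EasyDeL.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class StandInConfig:
    """Hypothetical stand-in for vWhisperInferenceConfig (not the real class)."""

    batch_size: Optional[int] = 1
    return_timestamps: Optional[bool] = None
    task: Optional[str] = None
    language: Optional[str] = None


def resolve_call_options(
    config: StandInConfig,
    batch_size: Optional[int] = None,
    language: Optional[str] = None,
    task: Optional[str] = None,
    return_timestamps: Optional[bool] = None,
) -> Dict[str, Any]:
    """Prefer per-call values; use the config value when the argument is None."""

    def pick(override: Any, default: Any) -> Any:
        return default if override is None else override

    return {
        "batch_size": pick(batch_size, config.batch_size),
        "language": pick(language, config.language),
        "task": pick(task, config.task),
        "return_timestamps": pick(return_timestamps, config.return_timestamps),
    }


config = StandInConfig(batch_size=4, task="transcribe")
# Only `task` is overridden for this call; everything else comes from the config.
options = resolve_call_options(config, task="translate")
```

Comparing against None rather than truthiness keeps an explicit return_timestamps=False from being replaced by the configured value.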

batch_size: Optional[int] = 1
classmethod from_dict(data)

Create an instance from a dictionary (deserialization).

classmethod from_json(json_str)

Create an instance from a JSON string.

generation_config: Optional[Any] = None
is_multilingual = None
language = None
logits_processor = None
max_length: Optional[int] = None
replace(**kwargs)
return_timestamps = None
task = None
to_dict()

Convert the instance to a dictionary for JSON serialization.

to_json(**kwargs)

Convert the instance to a JSON string.
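
The serialization helpers above (to_dict, from_dict, to_json, from_json, replace) follow a familiar round-trip pattern. The sketch below reproduces that pattern on a stand-in dataclass using only the standard library. StandInInferenceConfig is a hypothetical class written for illustration, and the actual EasyDeL implementation may handle nested or non-JSON-serializable fields differently.

```python
import dataclasses
import json
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class StandInInferenceConfig:
    """Hypothetical stand-in mirroring two of the documented fields."""

    batch_size: Optional[int] = 1
    max_length: Optional[int] = None

    def to_dict(self) -> Dict[str, Any]:
        # Convert the instance to a plain dictionary.
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "StandInInferenceConfig":
        # Rebuild an instance from a dictionary of field values.
        return cls(**data)

    def to_json(self, **kwargs: Any) -> str:
        # Extra keyword arguments are forwarded to json.dumps (an assumption).
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str: str) -> "StandInInferenceConfig":
        return cls.from_dict(json.loads(json_str))

    def replace(self, **kwargs: Any) -> "StandInInferenceConfig":
        # Return a modified copy; the original instance is left untouched.
        return dataclasses.replace(self, **kwargs)


config = StandInInferenceConfig(batch_size=4)
updated = config.replace(max_length=448)
restored = StandInInferenceConfig.from_json(updated.to_json(indent=2))
```

In this sketch replace returns a new object, so config keeps max_length=None while updated and restored compare equal.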