easydel.__init__
- class easydel.__init__.ArcticConfig(vocab_size=32000, hidden_size=4096, intermediate_size=14336, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=None, hidden_act='silu', max_position_embeddings=4096, initializer_range=0.02, rms_norm_eps=1e-05, use_cache=True, pad_token_id=None, bos_token_id=1, eos_token_id=2, tie_word_embeddings=False, rope_theta=1000000.0, sliding_window=None, attention_dropout=0.0, num_experts_per_tok=1, num_local_experts=8, router_aux_loss_coef=0.001, moe_layer_frequency=2, parallel_attn_mlp_res=False, moe_train_capacity_factor=1, moe_eval_capacity_factor=1, enable_expert_tensor_parallelism=False, moe_min_capacity=0, moe_token_dropping=True, quantization=None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, rope_scaling: Dict[str, Union[str, float]] = None, **kwargs)
Bases:
EasyDeLBaseConfig
Configuration class for the Arctic model. Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 32000) – Vocabulary size of the ARCTIC model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.
hidden_size (int, optional, defaults to 4096) – Dimensionality of the hidden representations (embeddings and decoder layers).
intermediate_size (int, optional, defaults to 14336) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer decoder.
num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer decoder.
num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer decoder, used for Grouped Query Attention. Defaults to num_attention_heads if not set.
hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) used in the decoder. Supported string values include “silu”, “gelu”, “relu”, “swish” and “gelu_new”.
max_position_embeddings (int, optional, defaults to 4096) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the rms normalization layers.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
pad_token_id (int, optional) – The index of the padding token in the vocabulary. Defaults to None.
bos_token_id (int, optional, defaults to 1) – The index of the beginning-of-sequence token in the vocabulary.
eos_token_id (int, optional, defaults to 2) – The index of the end-of-sequence token in the vocabulary.
tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.
rope_theta (float, optional, defaults to 1e6) – The theta value to use for rotary position embeddings.
sliding_window (int, optional) – The sliding window size to use for attention. If not specified, no sliding window attention is used.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
num_experts_per_tok (int, optional, defaults to 1) – The number of experts per token for mixture of experts.
num_local_experts (int, optional, defaults to 8) – The number of local experts for mixture of experts.
router_aux_loss_coef (float, optional, defaults to 0.001) – The auxiliary loss coefficient for the router.
moe_layer_frequency (int, optional, defaults to 2) – The frequency of MoE layers.
parallel_attn_mlp_res (bool, optional, defaults to False) – Whether to parallelize attention and MLP residual connections.
moe_train_capacity_factor (float, optional, defaults to 1) – The capacity factor for MoE layers during training.
moe_eval_capacity_factor (float, optional, defaults to 1) – The capacity factor for MoE layers during evaluation.
enable_expert_tensor_parallelism (bool, optional, defaults to False) – Whether to enable expert tensor parallelism.
moe_min_capacity (int, optional, defaults to 0) – The minimum capacity for MoE layers.
moe_token_dropping (bool, optional, defaults to True) – Whether to drop tokens in MoE layers.
quantization (str, optional) – The quantization configuration.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing strategy.
use_scan_mlp (bool, optional, defaults to False) – Whether to use scan for MLP.
scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size for scan MLP.
bits (int, optional) – The number of bits.
rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The rope scaling configuration.
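The capacity-related parameters (moe_train_capacity_factor, moe_eval_capacity_factor, moe_min_capacity, moe_token_dropping) are only named above. The sketch below assumes DeepSpeed-style MoE semantics: each expert accepts at most ceil(top_k * tokens / experts * capacity_factor) tokens, never fewer than moe_min_capacity, and overflow tokens are dropped when token dropping is enabled. Both helper names are hypothetical and not part of EasyDeL.

```python
import math
from typing import List


def expert_capacity(num_tokens: int, num_experts: int, top_k: int,
                    capacity_factor: float, min_capacity: int) -> int:
    """Hypothetical helper: max tokens a single expert may process (DeepSpeed-style)."""
    capacity = math.ceil(top_k * num_tokens / num_experts * capacity_factor)
    return max(capacity, min_capacity)


def route_with_capacity(expert_ids: List[int], num_experts: int, capacity: int,
                        token_dropping: bool = True) -> List[bool]:
    """Return a keep/drop flag per token for top-1 routing under a capacity limit."""
    counts = [0] * num_experts
    kept = []
    for e in expert_ids:
        if token_dropping and counts[e] >= capacity:
            kept.append(False)  # expert is full: this token skips the expert
        else:
            counts[e] += 1
            kept.append(True)
    return kept


# 8 tokens, 4 experts, top-1, capacity factor 1 -> each expert takes at most 2 tokens.
cap = expert_capacity(num_tokens=8, num_experts=4, top_k=1,
                      capacity_factor=1.0, min_capacity=0)
flags = route_with_capacity([0, 0, 0, 1, 2, 3, 3, 3], num_experts=4, capacity=cap)
print(cap, flags)  # 2 [True, True, False, True, True, True, True, False]
```

A larger capacity factor (for example, a higher moe_eval_capacity_factor at evaluation time) trades memory for fewer dropped tokens.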
- get_partition_rules(*args, **kwargs)
Get the partition rules for the model.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- static get_weight_decay_exclusions()
Returns a tuple of parameter names for which weight decay should be excluded.
- property granted_freq_max_position_embedding: int
Returns the maximum position embedding size for frequency-based position embeddings, falling back to max_position_embeddings if freq_max_position_embeddings is not set.
- property granted_mask_max_position_embedding: int
Returns the maximum position embedding size for mask-based position embeddings, falling back to max_position_embeddings if mask_max_position_embeddings is not set.
- model_type: str = 'arctic'
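Both granted_* properties follow the same fallback rule. A minimal sketch, assuming the optional overrides are stored as attributes that are None when unset; PositionConfigSketch is a hypothetical stand-in that only mirrors the documented behavior, not EasyDeL's config class.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class PositionConfigSketch:
    """Hypothetical stand-in for the position-embedding fields of a config."""
    max_position_embeddings: int = 4096
    freq_max_position_embeddings: Optional[int] = None
    mask_max_position_embeddings: Optional[int] = None

    @property
    def granted_freq_max_position_embedding(self) -> int:
        # Use the frequency-specific override if set, else the global maximum.
        if self.freq_max_position_embeddings is not None:
            return self.freq_max_position_embeddings
        return self.max_position_embeddings

    @property
    def granted_mask_max_position_embedding(self) -> int:
        # Same fallback rule for the mask-based maximum.
        if self.mask_max_position_embeddings is not None:
            return self.mask_max_position_embeddings
        return self.max_position_embeddings


cfg = PositionConfigSketch(freq_max_position_embeddings=8192)
print(cfg.granted_freq_max_position_embedding, cfg.granted_mask_max_position_embedding)  # 8192 4096
```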
- class easydel.__init__.ArcticForCausalLM(*args: Any, **kwargs: Any)
Bases:
EasyDeLBaseModule
Arctic model adapted for Causal Language Modeling (CLM). This module wraps the core ArcticModel and adds a language modeling head on top.
- config
Configuration object for the Arctic model.
- Type
ArcticConfig
- dtype
Data type for computation. Defaults to jnp.float32.
- Type
jnp.dtype
- param_dtype
Data type for parameters. Defaults to jnp.float32.
- Type
jnp.dtype
- precision
Precision setting for JAX operations. Defaults to None.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators for the module.
- Type
nn.Rngs
- class easydel.__init__.ArcticModel(*args: Any, **kwargs: Any)
Bases:
EasyDeLBaseModule
Core Arctic model architecture. This module implements the main Transformer stack for the Arctic model, including token embeddings and decoder layers.
- config
Configuration object for the Arctic model.
- Type
ArcticConfig
- dtype
Data type for computation. Defaults to jnp.float32.
- Type
jnp.dtype
- param_dtype
Data type for parameters. Defaults to jnp.float32.
- Type
jnp.dtype
- precision
Precision setting for JAX operations. Defaults to None.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators for the module.
- Type
nn.Rngs
- class easydel.__init__.AttentionMechanisms(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
Bases:
str, Enum
Enumeration of available attention mechanisms.
- AUTO
Automatically selects the best mechanism based on the backend.
- FLASH_ATTN2
FlashAttention-2 implementation.
- RING
RingAttention implementation.
- VANILLA
Standard dot-product attention.
- SPLASH
SplashAttention implementation (optimized for TPUs).
- CUDNN
cuDNN implementation (GPU-specific).
- BLOCKWISE
Blockwise attention computation.
- SDPA
Scaled Dot Product Attention (potentially uses JAX native SDPA).
- CUDA_FLASH_ATTN2
CUDA-specific FlashAttention-2 implementation.
- PAGED_ATTENTION
Paged attention for fast inference.
- AUTO = 'auto'
- BLOCKWISE = 'blockwise'
- CUDA_FLASH_ATTN2 = 'cuda_flash_attn2'
- CUDNN = 'cudnn'
- FLASH_ATTN2 = 'flash_attn2'
- PAGED_ATTENTION = 'paged_attention'
- RING = 'ring'
- SDPA = 'sdpa'
- SPLASH = 'splash'
- VANILLA = 'vanilla'
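Because AttentionMechanisms subclasses both str and Enum, its members compare equal to their string values and can be looked up from a plain string, which is convenient when the mechanism comes from a config file. The sketch below uses a hypothetical local replica holding a subset of the documented values rather than importing EasyDeL.

```python
from enum import Enum


class AttentionMechanismsSketch(str, Enum):
    """Hypothetical replica of a subset of AttentionMechanisms."""
    AUTO = "auto"
    FLASH_ATTN2 = "flash_attn2"
    SPLASH = "splash"
    SDPA = "sdpa"
    VANILLA = "vanilla"


def resolve_mechanism(value) -> AttentionMechanismsSketch:
    # Accepts either an enum member or its string value; unknown strings raise ValueError.
    return AttentionMechanismsSketch(value)


mech = resolve_mechanism("splash")
print(mech is AttentionMechanismsSketch.SPLASH, mech == "splash")  # True True
```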
- class easydel.__init__.AttentionMetadata(runtime_dtype: Union[str, type[Any], dtype, SupportsDType], runtime_softmax_dtype: Optional[Union[str, type[Any], dtype, SupportsDType]] = None, sequence_axis_name: str = Ellipsis, mesh: Optional[Mesh] = Ellipsis, platform: EasyDeLPlatforms = Ellipsis, backend: EasyDeLBackends = Ellipsis, partition_axis: PartitionAxis = Ellipsis, base_config: Optional[EasyDeLBaseConfig] = None, scan_ring_attention: bool = Ellipsis, softmax_scale: float = Ellipsis, dropout_prob: float = Ellipsis, blocksize_q: int = Ellipsis, blocksize_k: int = Ellipsis, blocksize_b: int = Ellipsis)
Bases:
object
Holds configuration, context, and metadata for attention operations.
This class centralizes various parameters needed by different attention implementations, facilitating consistent behavior and configuration. It handles default values and can be initialized from an EasyDeLBaseConfig.
- runtime_dtype
The primary JAX dtype for computations (e.g., q, k, v).
- Type
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType]
- runtime_softmax_dtype
Optional JAX dtype for the softmax computation, allowing for higher precision if needed (e.g., float32).
- Type
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType]]
- sequence_axis_name
The name used for the sequence axis in JAX parallelism (axis_names for pjit).
- Type
str
- mesh
The JAX device mesh for distributed computation. Must be provided or inferred from context.
- Type
Optional[jax._src.mesh.Mesh]
- platform
The target hardware platform (e.g., TPU, GPU).
- Type
EasyDeLPlatforms
- backend
The specific JAX backend being used (e.g., TPU, CUDA, ROCM).
- Type
EasyDeLBackends
- partition_axis
Configuration for partitioning axes in distributed settings (likely from eformer.escale).
- Type
PartitionAxis
- base_config
An optional reference to the base model configuration object for sourcing default values.
- Type
Optional[EasyDeLBaseConfig]
- scan_ring_attention
Boolean flag indicating whether to use ring attention via jax.lax.scan.
- Type
bool
- softmax_scale
The scaling factor applied to the attention scores before the softmax. Often 1 / sqrt(head_dim).
- Type
float
- dropout_prob
The dropout probability applied to attention weights.
- Type
float
- blocksize_q
Block size for the query sequence dimension in blockwise attention.
- Type
int
- blocksize_k
Block size for the key/value sequence dimension in blockwise attention.
- Type
int
- blocksize_b
Block size for the batch dimension in blockwise attention (often 1).
- Type
int
- backend: EasyDeLBackends = Ellipsis
- base_config: Optional[EasyDeLBaseConfig] = None
- blocksize_b: int = Ellipsis
- blocksize_k: int = Ellipsis
- blocksize_q: int = Ellipsis
- dropout_prob: float = Ellipsis
- classmethod from_config(config: EasyDeLBaseConfig, softmax_scale: float, dropout_prob: float = 0.0) → AttentionMetadata
Factory method to create AttentionMetadata from an EasyDeLBaseConfig.
- Parameters
config – The base configuration object (e.g., model config).
softmax_scale – The attention softmax scaling factor. Usually calculated based on head dimension.
dropout_prob – The attention dropout probability. Defaults to 0.0.
- Returns
An initialized AttentionMetadata instance.
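from_config expects the caller to supply softmax_scale. A minimal sketch of the common 1 / sqrt(head_dim) choice, assuming head_dim = hidden_size // num_attention_heads (true for many, but not all, models); default_softmax_scale is a hypothetical helper, not an EasyDeL function.

```python
import math


def default_softmax_scale(hidden_size: int, num_attention_heads: int) -> float:
    """Hypothetical helper: 1 / sqrt(head_dim) with head_dim derived from the config."""
    if hidden_size % num_attention_heads != 0:
        raise ValueError("hidden_size must be divisible by num_attention_heads")
    head_dim = hidden_size // num_attention_heads
    return 1.0 / math.sqrt(head_dim)


# Arctic defaults from above: hidden_size=4096, num_attention_heads=32 -> head_dim=128.
scale = default_softmax_scale(4096, 32)
print(scale)
```

The resulting value would then be passed as AttentionMetadata.from_config(config, softmax_scale=scale), following the signature documented above.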
- classmethod from_dict(data)
Create an instance from a dictionary (deserialization).
- classmethod from_json(json_str)
Create an instance from a JSON string.
- get_partition_specs(mode: RuntimeType, BTHD: bool = True)
Generates JAX PartitionSpecs for attention tensors based on runtime mode.
- Parameters
mode – The current runtime mode (normal or generation).
BTHD – Boolean indicating tensor layout. True for (Batch, Time, Head, Dim), False for (Batch, Head, Time, Dim).
- Returns
A tuple containing PartitionSpecs for (query, key, value, bias, mask, attention_output).
- Return type
tuple
- partition_axis: PartitionAxis = Ellipsis
- platform: EasyDeLPlatforms = Ellipsis
- replace(**kwargs)
- scan_ring_attention: bool = Ellipsis
- sequence_axis_name: str = Ellipsis
- set_attrs_carefully(attr_name: str, default: Optional[Any], pickup_name: Optional[str] = None, use_base_config: bool = True)
Internal helper to set an attribute if it’s not already set (or is Ellipsis).
Optionally retrieves the value from self.base_config using pickup_name (or attr_name if pickup_name is None).
- Parameters
attr_name – The name of the attribute to set on self.
default – The default value to use if not found in base_config or if use_base_config is False.
pickup_name – The name of the attribute to look for in base_config. Defaults to attr_name.
use_base_config – Whether to attempt retrieving the value from base_config.
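The Ellipsis defaults in the signature act as an "unset" sentinel that set_attrs_carefully later resolves. The sketch below mirrors the documented lookup order (explicit value, then base_config via pickup_name, then default) on a hypothetical MetadataSketch class; the default values 512 and 0.0 are illustrative, and this is not EasyDeL's implementation.

```python
from types import SimpleNamespace
from typing import Any, Optional


class MetadataSketch:
    """Hypothetical stand-in showing how Ellipsis marks an unset field."""

    def __init__(self, base_config: Optional[Any] = None,
                 blocksize_q: Any = Ellipsis, dropout_prob: Any = Ellipsis):
        self.base_config = base_config
        self.blocksize_q = blocksize_q
        self.dropout_prob = dropout_prob
        # Resolve unset fields; the pickup name may differ from the attribute name.
        self.set_attrs_carefully("blocksize_q", 512)
        self.set_attrs_carefully("dropout_prob", 0.0, pickup_name="attention_dropout")

    def set_attrs_carefully(self, attr_name: str, default: Optional[Any],
                            pickup_name: Optional[str] = None,
                            use_base_config: bool = True) -> None:
        if getattr(self, attr_name, Ellipsis) is not Ellipsis:
            return  # an explicit value was passed: keep it
        name = attr_name if pickup_name is None else pickup_name
        if use_base_config and self.base_config is not None and hasattr(self.base_config, name):
            setattr(self, attr_name, getattr(self.base_config, name))
        else:
            setattr(self, attr_name, default)


base = SimpleNamespace(attention_dropout=0.1)
meta = MetadataSketch(base_config=base, blocksize_q=256)
print(meta.blocksize_q, meta.dropout_prob)  # 256 0.1
```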
- softmax_scale: float = Ellipsis
- to_dict()
Convert the instance to a dictionary for JSON serialization.
- to_json(**kwargs)
Convert the instance to a JSON string.
- class easydel.__init__.AttentionModule(*args: Any, **kwargs: Any)
Bases:
Module
Base class for Flax attention modules in EasyDeL, providing common utilities.
This class offers helper functions and attributes commonly needed by attention implementations within Flax, such as handling KV caching, sharding, mask manipulation, and head manipulation. Concrete attention implementations often inherit from this class.
- config
Configuration object for the attention module.
- Type
SC | EasyDeLBaseConfig
- cached_key
Flax Cache for storing past key states (not used).
- Type
nn.Cache[Array] | None
- cached_value
Flax Cache for storing past value states (not used).
- Type
nn.Cache[Array] | None
- cache_index
Flax Cache for tracking the current index in the cache (not used).
- Type
nn.Cache[Array] | None
- static build_cache_pos(attention_mask: Array, cache_view: TransformerCacheView = None) → Array
Calculates the position indices within the sequence for cache-aware operations.
- Parameters
attention_mask (jax.Array) – The attention mask (typically [batch, heads, q_len, k_len]).
cache_view (TransformerCacheView, optional) – The current KV cache view. Defaults to None.
- Returns
An array representing the position of each token in the sequence, adjusted by the cache index if provided. Shape usually [batch, q_len].
- Return type
jax.Array
- concatenate(*, query: Union[Array, ndarray, bool, number], key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], attention_mask: Union[Array, ndarray, bool, number], cache_view: Optional[Union[TransformerCacheView, PagedAttentionCacheView]] = None, cache_metadata: Optional[Union[TransformerMetadata, PagedAttentionMetadata]] = None, causal_mask: Optional[Union[Array, ndarray, bool, number]] = None, token_type_ids: Optional[Union[Array, ndarray, bool, number]] = None, fcm_mask: Optional[Union[Array, ndarray, bool, number]] = None, sliding_windows: Optional[int] = None) → Tuple[Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Callable[[], Union[Array, ndarray, bool, number]]]
Prepares inputs for attention calculation, handling KV caching and mask merging.
This function combines the current query, key, and value with cached states (if applicable), merges various masks (attention, causal, FCM, sliding window), and returns the final key, value, attention mask, and a function to initialize the attention bias.
- Parameters
query (Array) – Current query states [Batch, q_len, Heads, Dim].
key (Array) – Current key states [Batch, kv_len, Heads, Dim].
value (Array) – Current value states [Batch, kv_len, Heads, Dim].
attention_mask (Array) – Base attention mask (e.g., padding mask) [Batch, kv_len] or compatible.
cache_view (tp.Optional[TransformerCacheView], optional) – View into the KV cache. If None, caching is disabled. Defaults to None.
causal_mask (tp.Optional[Array], optional) – Causal mask [1, 1, q_len, kv_len]. Defaults to None.
token_type_ids (tp.Optional[Array], optional) – Token type IDs for segment masking [Batch, q_len]. Defaults to None.
fcm_mask (tp.Optional[Array], optional) – Fused-Context-Mask (specific use case) [Batch, 1, q_len, kv_len]. Defaults to None.
sliding_windows (tp.Optional[int], optional) – Size of the sliding attention window. If None, not applied. Defaults to None.
- Returns
key_states (Array): Final key states (potentially from cache).
value_states (Array): Final value states (potentially from cache).
attention_mask (Array): The final combined attention mask [Batch, Heads, q_len, kv_len].
init_attention_bias (Callable): Function to create the attention bias tensor.
- Return type
tp.Tuple[Array, Array, Array, tp.Callable[[], Array]]
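The mask-merging step of concatenate can be pictured on a single sequence: a key is visible to a query only if it is not padding and not in the query's future. A pure-Python sketch, assuming a boolean key padding mask and a cache offset that shifts query positions during decoding; it ignores batching, heads, FCM masks, and token types. The lazily evaluated bias mirrors the returned init_attention_bias callable, and the -1e9 fill value is an assumption (implementations often use the dtype's minimum). Both function names are hypothetical.

```python
from typing import Callable, List

Mask = List[List[bool]]


def merge_masks(padding_mask: List[bool], q_len: int, cache_offset: int = 0) -> Mask:
    """Combine a key padding mask with a causal mask -> [q_len][kv_len] booleans.

    Query i sits at absolute position cache_offset + i; key j sits at position j.
    """
    kv_len = len(padding_mask)
    mask = []
    for i in range(q_len):
        q_pos = cache_offset + i
        mask.append([padding_mask[j] and j <= q_pos for j in range(kv_len)])
    return mask


def make_bias_fn(mask: Mask) -> Callable[[], List[List[float]]]:
    # Deferred like init_attention_bias: the additive bias is built only when called.
    neg = -1e9  # assumption: a large negative number stands in for "masked"
    return lambda: [[0.0 if m else neg for m in row] for row in mask]


# Prefill of 3 tokens where the last key is right-padding.
print(merge_masks([True, True, False], q_len=3))
# [[True, False, False], [True, True, False], [True, True, False]]
```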
- property default_key_value_sharding
Defines the default JAX sharding for key and value tensors.
Uses the partition specifications defined in the configuration’s partition_axis.
- Returns
The default sharding configuration for K/V tensors.
- Return type
NamedSharding
- get_sharding_safely(tensor: Array) → PartitionSpec
Retrieves the PartitionSpec of a tensor, falling back to the default KV sharding.
- Parameters
tensor (jax.Array) – The tensor whose sharding spec is needed.
- Returns
The sharding specification of the tensor.
- Return type
PartitionSpec
- make_flexible_sliding_window(attention_mask: Array, cache_view: TransformerCacheView, sliding_window: int)
Applies a sliding window mask to the attention mask, considering cache state.
- Parameters
attention_mask (jax.Array) – The original attention mask.
cache_view (TransformerCacheView) – The current view of the KV cache.
sliding_window (int) – The size of the sliding window.
- Returns
A tuple of the attention mask combined with the sliding window mask and a function (init_attention_bias) that creates the corresponding attention bias.
- Return type
tuple
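make_flexible_sliding_window additionally limits each query to the most recent sliding_window keys. A hypothetical pure-Python sketch of that rule, assuming a query at absolute position p may attend to keys in the range (p - sliding_window, p], with p shifted by the cache index during decoding; implementations differ on whether the window boundary is inclusive, so treat the exact edge as an assumption.

```python
from typing import List

Mask = List[List[bool]]


def sliding_window_mask(q_len: int, kv_len: int, sliding_window: int,
                        cache_offset: int = 0) -> Mask:
    """Causal sliding-window mask: the query at position p sees keys p-window+1 .. p."""
    mask = []
    for i in range(q_len):
        p = cache_offset + i
        mask.append([p - sliding_window < j <= p for j in range(kv_len)])
    return mask


def apply_window(attention_mask: Mask, window_mask: Mask) -> Mask:
    # Element-wise AND: a key must pass both the existing mask and the window.
    return [[a and w for a, w in zip(ar, wr)] for ar, wr in zip(attention_mask, window_mask)]


print(sliding_window_mask(4, 4, sliding_window=2))
# [[True, False, False, False], [True, True, False, False], [False, True, True, False], [False, False, True, True]]
```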
- property quantizer
Provides an EasyQuantizer instance based on the module’s configuration.
Used for quantizing KV cache entries if enabled in the config.
- Returns
The quantizer instance.
- Return type
EasyQuantizer
- static repeat_key_value(key, value, num_reps: int)
Repeats key and value tensors for Grouped Query Attention (GQA).
Expands the head dimension by repeating num_reps times. Uses einops for concise repetition.
- Parameters
key (Array) – Key tensor [Batch, Seq, NumKVHeads, Dim].
value (Array) – Value tensor [Batch, Seq, NumKVHeads, Dim].
num_reps (int) – The number of times to repeat each KV head (num_attention_heads / num_kv_heads).
- Returns
Repeated key and value tensors, each with shape [Batch, Seq, NumKVHeads * num_reps, Dim].
- Return type
tp.Tuple[Array, Array]
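repeat_key_value expands each key/value head num_reps times so that groups of query heads share one KV head. A dependency-free sketch on nested lists, assuming the copies of a head are placed next to each other (the ordering an einops pattern such as "b s h d -> b s (h r) d" produces); the function names are hypothetical, and EasyDeL's own version operates on JAX arrays.

```python
from typing import List, Tuple

Tensor4D = List[List[List[List[float]]]]  # [batch][seq][heads][dim]


def repeat_heads(x: Tensor4D, num_reps: int) -> Tensor4D:
    # Each head h is emitted num_reps times in a row: h0, h0, h1, h1, ...
    return [[[list(head) for head in pos for _ in range(num_reps)] for pos in seq] for seq in x]


def repeat_key_value_sketch(key: Tensor4D, value: Tensor4D,
                            num_reps: int) -> Tuple[Tensor4D, Tensor4D]:
    return repeat_heads(key, num_reps), repeat_heads(value, num_reps)


# 8 query heads sharing 2 KV heads -> num_reps = 8 // 2 = 4.
num_attention_heads, num_key_value_heads = 8, 2
num_reps = num_attention_heads // num_key_value_heads
key = [[[[1.0, 2.0], [3.0, 4.0]]]]    # batch=1, seq=1, kv_heads=2, dim=2
value = [[[[5.0, 6.0], [7.0, 8.0]]]]
k, v = repeat_key_value_sketch(key, value, num_reps)
print(len(k[0][0]))  # 8
```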
- class easydel.__init__.AttentionRegistry
Bases:
object
Registry for discovering and managing different AttentionImpl classes.
Allows registering implementations using a decorator and retrieving or instantiating them by name.
- classmethod create(impl_name: str, metadata: AttentionMetadata) → AttentionImpl
Creates an instance of an attention implementation by name.
Retrieves the class associated with impl_name and initializes it with the provided metadata.
- Parameters
impl_name – The name of the implementation to instantiate.
metadata – The AttentionMetadata to pass to the implementation’s constructor.
- Returns
An initialized instance of the requested AttentionImpl subclass.
- Raises
ValueError – If no implementation is registered with impl_name.
- classmethod get(impl_name: str) → Type[AttentionImpl]
Retrieves an attention implementation class by its registered name.
- Parameters
impl_name – The name of the implementation to retrieve.
- Returns
The AttentionImpl subclass registered under the given name.
- Raises
ValueError – If no implementation is registered with that name.
- classmethod list_implementations() → List[str]
Returns a list of names of all registered attention implementations.
- Returns
A list of strings, where each string is a registered implementation name.
- classmethod register(impl_cls: Type[ICa]) → Type[ICa]
Class method decorator to register an AttentionImpl subclass.
The implementation is registered under the name(s) returned by its get_impl_name() class method.
Example:
```python
@AttentionRegistry.register
class FlashAttentionImpl(AttentionImpl):
    @classmethod
    def get_impl_name(cls) -> str:
        return "flash"

    # ... implementation ...
```
- Parameters
impl_cls – The AttentionImpl subclass to register.
- Returns
The registered class itself.
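The registry pattern described above fits in a few lines. This is a hypothetical standalone sketch, not EasyDeL's AttentionRegistry: it assumes get_impl_name() may return a single name or a tuple of names, that create passes the metadata object straight to the constructor, and it sorts list_implementations() only to make the output deterministic.

```python
from typing import Dict, List, Tuple, Type, TypeVar, Union


class ImplBase:
    """Hypothetical stand-in for AttentionImpl."""

    def __init__(self, metadata):
        self.metadata = metadata

    @classmethod
    def get_impl_name(cls) -> Union[str, Tuple[str, ...]]:
        raise NotImplementedError


T = TypeVar("T", bound=Type[ImplBase])


class RegistrySketch:
    _registry: Dict[str, Type[ImplBase]] = {}

    @classmethod
    def register(cls, impl_cls: T) -> T:
        names = impl_cls.get_impl_name()
        if isinstance(names, str):
            names = (names,)
        for name in names:
            cls._registry[name] = impl_cls
        return impl_cls  # returned unchanged so it works as a decorator

    @classmethod
    def get(cls, impl_name: str) -> Type[ImplBase]:
        if impl_name not in cls._registry:
            raise ValueError(f"No attention implementation registered as {impl_name!r}")
        return cls._registry[impl_name]

    @classmethod
    def create(cls, impl_name: str, metadata) -> ImplBase:
        return cls.get(impl_name)(metadata)

    @classmethod
    def list_implementations(cls) -> List[str]:
        return sorted(cls._registry)  # sorted for deterministic output


@RegistrySketch.register
class VanillaSketch(ImplBase):
    @classmethod
    def get_impl_name(cls):
        return ("vanilla", "dot_product")


impl = RegistrySketch.create("vanilla", metadata={"softmax_scale": 0.125})
print(type(impl).__name__, RegistrySketch.list_implementations())  # VanillaSketch ['dot_product', 'vanilla']
```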
- class easydel.__init__.AutoEasyDeLConfig
Bases:
object
- classmethod from_pretrained(pretrained_model_name_or_path: str, sharding_axis_dims: Sequence[int] = (1, -1, 1, 1), sharding_dcn_axis_dims: Optional[Sequence[int]] = None, sharding_axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), partition_axis: Optional[PartitionAxis] = None, shard_attention_computation: bool = True, backend: Optional[EasyDeLBackends] = None, platform: Optional[EasyDeLPlatforms] = None, model_task: TaskType = TaskType.CAUSAL_LM, from_torch: bool = False, **kwargs) → EasyDeLBaseConfig
Loads the configuration of a pretrained model, identified by name on the Hugging Face Hub or by a local path, and returns it as an EasyDeLBaseConfig configured with the given sharding and backend options.
- Parameters
pretrained_model_name_or_path – str: Model identifier on the Hugging Face Hub, or path to a local model directory.
sharding_axis_dims – tp.Sequence[int]: The size of each sharding axis. Defaults to (1, -1, 1, 1).
sharding_axis_names – tp.Sequence[str]: The names of the sharding axes. Defaults to (“dp”, “fsdp”, “tp”, “sp”).
shard_attention_computation – bool: Whether to use shard_map for the attention computation. Defaults to True.
backend – tp.Optional[EasyDeLBackends]: Backend to use for the model.
model_task – TaskType: Task type of the model to load. Defaults to TaskType.CAUSAL_LM.
from_torch – bool: Whether to load the configuration from a PyTorch model. Defaults to False.
**kwargs – Additional arguments passed to the model and config classes.
- Returns
The loaded EasyDeLBaseConfig.
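The default sharding_axis_dims=(1, -1, 1, 1) contains a -1. A minimal sketch, assuming -1 is inferred the way NumPy's reshape infers a dimension, so that the product of all axis sizes equals the number of devices; resolve_axis_dims is a hypothetical helper, not an EasyDeL function.

```python
from math import prod
from typing import Sequence, Tuple


def resolve_axis_dims(axis_dims: Sequence[int], num_devices: int) -> Tuple[int, ...]:
    """Replace a single -1 with whatever size makes the mesh use every device."""
    if list(axis_dims).count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    known = prod(d for d in axis_dims if d != -1)
    if num_devices % known != 0:
        raise ValueError(f"{num_devices} devices cannot be split across {tuple(axis_dims)}")
    inferred = num_devices // known
    resolved = tuple(inferred if d == -1 else d for d in axis_dims)
    if prod(resolved) != num_devices:
        raise ValueError("axis sizes must multiply to the device count")
    return resolved


# Names follow the documented default ("dp", "fsdp", "tp", "sp").
print(dict(zip(("dp", "fsdp", "tp", "sp"), resolve_axis_dims((1, -1, 1, 1), 8))))
# {'dp': 1, 'fsdp': 8, 'tp': 1, 'sp': 1}
```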
- class easydel.__init__.AutoEasyDeLModel
Bases:
BaseAutoEasyModel
This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL-compatible models. It utilizes the EasyDeL library for distributed training and inference with JAX.
It inherits from BaseAutoEasyModel, which provides model loading, parameter sharding, and interaction with the EasyDeL framework.
- class easydel.__init__.AutoEasyDeLModelForCausalLM
Bases:
BaseAutoEasyModel
This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL-compatible models. It utilizes the EasyDeL library for distributed training and inference with JAX.
It inherits from BaseAutoEasyModel, which provides model loading, parameter sharding, and interaction with the EasyDeL framework.
Examples
```python
>>> import jax
>>> from easydel import AutoEasyDeLModelForCausalLM

>>> # Load a GPT-2 model on a single CPU
>>> model = AutoEasyDeLModelForCausalLM.from_pretrained(
...     "gpt2", device=jax.devices("cpu")[0]
... )

>>> # Load a GPT-2 model sharded across 8 GPUs along the fully sharded data parallel (FSDP) axis
>>> model = AutoEasyDeLModelForCausalLM.from_pretrained(
...     "gpt2",
...     sharding_axis_dims=(1, 8, 1, 1),
...     sharding_axis_names=("dp", "fsdp", "tp", "sp"),
...     device=jax.devices("cpu")[0],  # offload to CPU [OPTIONAL]
...     from_torch=True,
... )
```
- class easydel.__init__.AutoEasyDeLModelForImageTextToText
Bases:
BaseAutoEasyModel
This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL-compatible models. It utilizes the EasyDeL library for distributed training and inference with JAX.
It inherits from BaseAutoEasyModel, which provides model loading, parameter sharding, and interaction with the EasyDeL framework.
- class easydel.__init__.AutoEasyDeLModelForSeq2SeqLM
Bases:
BaseAutoEasyModel
This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL-compatible models. It utilizes the EasyDeL library for distributed training and inference with JAX.
It inherits from BaseAutoEasyModel, which provides model loading, parameter sharding, and interaction with the EasyDeL framework.
- class easydel.__init__.AutoEasyDeLModelForSequenceClassification
Bases:
BaseAutoEasyModel
This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL-compatible models. It utilizes the EasyDeL library for distributed training and inference with JAX.
It inherits from BaseAutoEasyModel, which provides model loading, parameter sharding, and interaction with the EasyDeL framework.
- class easydel.__init__.AutoEasyDeLModelForSpeechSeq2Seq
Bases:
BaseAutoEasyModel
This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL-compatible models. It utilizes the EasyDeL library for distributed training and inference with JAX.
It inherits from BaseAutoEasyModel, which provides model loading, parameter sharding, and interaction with the EasyDeL framework.
Examples
```python
>>> import jax
>>> from easydel import AutoEasyDeLModelForSpeechSeq2Seq

>>> # Load openai/whisper-large-v3-turbo with automatic sharding
>>> model = AutoEasyDeLModelForSpeechSeq2Seq.from_pretrained(
...     "openai/whisper-large-v3-turbo",
...     auto_shard_model=True,
... )

>>> # Load openai/whisper-large-v3-turbo sharded across 8 GPUs along the fully sharded data parallel (FSDP) axis
>>> model = AutoEasyDeLModelForSpeechSeq2Seq.from_pretrained(
...     "openai/whisper-large-v3-turbo",
...     sharding_axis_dims=(1, 8, 1, 1),
...     sharding_axis_names=("dp", "fsdp", "tp", "sp"),
...     device=jax.devices("cpu")[0],  # offload to CPU [OPTIONAL]
...     from_torch=True,
... )
```
- class easydel.__init__.AutoEasyDeLModelForZeroShotImageClassification
Bases:
BaseAutoEasyModel
This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL-compatible models. It utilizes the EasyDeL library for distributed training and inference with JAX.
It inherits from BaseAutoEasyModel, which provides model loading, parameter sharding, and interaction with the EasyDeL framework.
- class easydel.__init__.AutoEasyDeLVisionModel
Bases:
BaseAutoEasyModel
This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL-compatible models. It utilizes the EasyDeL library for distributed training and inference with JAX.
It inherits from BaseAutoEasyModel, which provides model loading, parameter sharding, and interaction with the EasyDeL framework.
- class easydel.__init__.AutoShardAndGatherFunctions
Bases:
object
A class to automatically generate shard and gather functions for a given model configuration.
This class provides the following methods to generate shard and gather functions:
from_config: Generates functions based on a provided EasyDeLBaseConfig object.
from_pretrained: Generates functions based on a pretrained model name or path.
from_params: Generates functions directly from model parameters, partition rules, and a mesh.
- classmethod from_config(config: EasyDeLBaseConfig, partition_rules: Optional[Tuple[Tuple[str, PartitionSpec]]] = None, flatten: bool = True, model_task: TaskType = TaskType.CAUSAL_LM, depth_target: Optional[List[str]] = None)
Generates shard and gather functions based on a provided EasyDeLBaseConfig object.
- Parameters
config – An EasyDeLBaseConfig object containing the model configuration.
partition_rules – A tuple of tuples containing partition rule names and PartitionSpec objects. If None, uses the default partition rules from the config.
flatten – Whether to flatten the shard and gather functions. Defaults to True.
model_task (TaskType) – The task type of the model to load. Defaults to TaskType.CAUSAL_LM.
depth_target – Extra keys under which the function tree is nested; for example, depth_target = ["row"] turns {params: tensor} into {row: {params: tensor}}. Defaults to None.
- Returns
A tuple containing the shard and gather functions.
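The depth_target option wraps the generated function tree in extra nesting levels. A minimal sketch of the documented example, assuming multiple keys are applied outermost-first; nest_under is a hypothetical helper, not part of EasyDeL.

```python
from typing import Any, Dict, List, Optional


def nest_under(tree: Dict[str, Any], depth_target: Optional[List[str]]) -> Dict[str, Any]:
    """Wrap tree so that depth_target=["a", "b"] yields {"a": {"b": tree}}."""
    if not depth_target:
        return tree
    for key in reversed(depth_target):
        tree = {key: tree}
    return tree


print(nest_under({"params": "tensor"}, ["row"]))  # {'row': {'params': 'tensor'}}
```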
- static from_params(params, partition_rules, mesh)
Generates shard and gather functions directly from model parameters, partition rules, and a mesh.
- Parameters
params – The model parameters (pytree) to generate functions for.
partition_rules – A tuple of tuples defining the partitioning strategy.
mesh – The JAX device mesh to use for sharding.
- Returns
A tuple containing the shard and gather functions.
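Partition rules pair a parameter-name pattern with a PartitionSpec. A dependency-free sketch of how such rules are commonly resolved, assuming regex patterns matched with re.search against a "/"-joined parameter path and first-match-wins semantics. PartitionSpec is replaced by plain tuples of mesh-axis names, and the rules and parameter tree are hypothetical examples.

```python
import re
from typing import Any, Dict, Optional, Sequence, Tuple

Spec = Tuple[Optional[Any], ...]  # stand-in for jax.sharding.PartitionSpec

# Hypothetical rules; real models provide their own via config.get_partition_rules().
RULES: Sequence[Tuple[str, Spec]] = (
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"(q_proj|k_proj|v_proj)/kernel", ("fsdp", "tp")),
    (r".*", (None,)),  # catch-all: replicate everything else
)


def flatten_paths(tree: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Flatten a nested dict into {"a/b/c": leaf}."""
    flat = {}
    for key, value in tree.items():
        path = f"{prefix}/{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten_paths(value, path))
        else:
            flat[path] = value
    return flat


def match_specs(params: Dict[str, Any], rules: Sequence[Tuple[str, Spec]]) -> Dict[str, Spec]:
    specs = {}
    for path in flatten_paths(params):
        for pattern, spec in rules:
            if re.search(pattern, path):
                specs[path] = spec  # first matching rule wins
                break
        else:
            raise ValueError(f"no partition rule matches {path!r}")
    return specs


params = {"model": {"embed_tokens": {"embedding": 0},
                    "layers": {"0": {"q_proj": {"kernel": 0}, "norm": {"scale": 0}}}}}
for path, spec in match_specs(params, RULES).items():
    print(path, spec)
```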
- classmethod from_pretrained(pretrained_model_name_or_path: str, sharding_axis_dims: Sequence[int] = (1, -1, 1, 1), sharding_dcn_axis_dims: Optional[Sequence[int]] = None, sharding_axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), partition_axis: Optional[PartitionAxis] = None, shard_attention_computation: bool = True, backend: Optional[EasyDeLBackends] = None, platform: Optional[EasyDeLPlatforms] = None, partition_rules: Optional[Tuple[Tuple[str, PartitionSpec]]] = None, flatten: bool = True, config_kwargs: Optional[Mapping[str, Any]] = None, model_task: TaskType = TaskType.CAUSAL_LM, from_torch: bool = False, trust_remote_code: bool = False) → Tuple[Mapping[str, Callable], Mapping[str, Callable]]
Generates shard and gather functions based on a pretrained model name or path.
- Parameters
pretrained_model_name_or_path – The name or path of the pretrained model.
sharding_axis_dims – The dimensions of the sharding axes. Defaults to (1, -1, 1, 1).
sharding_axis_names – The names of the sharding axes. Defaults to (“dp”, “fsdp”, “tp”, “sp”).
partition_axis (PartitionAxis) – Configuration of the partitioning axes used to shard arrays in EasyDeL. Defaults to None.
shard_attention_computation – Whether to shard the attention computation. Defaults to True.
backend – The backend to use for custom kernels. Defaults to None.
partition_rules – A tuple of tuples containing partition rule names and PartitionSpec objects. If None, uses the default partition rules from the config.
flatten – Whether to flatten the shard and gather functions. Defaults to True.
config_kwargs – Additional keyword arguments to pass to the AutoEasyDeLConfig constructor. Defaults to None.
model_task (TaskType) – The task type of the model to load. Defaults to TaskType.CAUSAL_LM.
from_torch (bool) – Whether to load the configuration from a PyTorch model. Defaults to False.
trust_remote_code (bool) – Whether to trust remote code loaded from the Hugging Face Hub. Defaults to False.
- Returns
A tuple containing the shard and gather functions.
- class easydel.__init__.AyaVisionConfig(vision_config=None, text_config=None, vision_feature_select_strategy='full', vision_feature_layer=-1, downsample_factor=2, adapter_layer_norm_eps=1e-06, image_token_index=255036, **kwargs)
Bases:
EasyDeLBaseConfig
This is the configuration class to store the configuration of an [AyaVisionForConditionalGeneration]. It is used to instantiate an AyaVision model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration similar to that of AyaVision, e.g. [CohereForAI/aya-vision-8b](https://huggingface.co/CohereForAI/aya-vision-8b).
Configuration objects inherit from [PretrainedConfig] and can be used to control the model outputs. Read the documentation from [PretrainedConfig] for more information.
- Parameters
vision_config (Union[AutoConfig, dict], optional, defaults to CLIPVisionConfig) – The config object or dictionary of the vision backbone.
text_config (Union[AutoConfig, dict], optional, defaults to LlamaConfig) – The config object or dictionary of the text backbone.
vision_feature_select_strategy (str, optional, defaults to “full”) – The feature selection strategy used to select the vision feature from the vision backbone. Can be one of “default” or “full”. If “default”, the CLS token is removed from the vision features. If “full”, the full vision features are used.
vision_feature_layer (int, optional, defaults to -1) – The index of the layer to select the vision feature.
downsample_factor (int, optional, defaults to 2) – The downsample factor to apply to the vision features.
adapter_layer_norm_eps (float, optional, defaults to 1e-06) – The epsilon value used for layer normalization in the adapter.
image_token_index (int, optional, defaults to 255036) – The image token index to encode the image prompt.
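The feature-selection and downsampling parameters above can be illustrated with plain arithmetic. The sketch assumes, as in the Hugging Face AyaVision implementation, that "default" drops the first (CLS) token and that the projector applies a pixel shuffle which merges each downsample_factor × downsample_factor block of patches into one token. The helper names and the 26 × 26 patch grid are illustrative assumptions, not values from this reference:

```python
def select_vision_features(features, strategy="full"):
    """Pick patch features from one image's vision-backbone output.

    features: list of per-token feature vectors, CLS token first (assumed layout).
    """
    if strategy == "default":
        return features[1:]  # drop the CLS token
    if strategy == "full":
        return features  # keep every token
    raise ValueError(f"unknown strategy: {strategy!r}")


def tokens_after_downsample(grid_h, grid_w, downsample_factor=2):
    """Token count after a pixel-shuffle style downsample of a patch grid.

    Each downsample_factor x downsample_factor block of patches becomes one
    token whose channel dimension grows by downsample_factor ** 2.
    """
    if grid_h % downsample_factor or grid_w % downsample_factor:
        raise ValueError("grid must be divisible by downsample_factor")
    return (grid_h // downsample_factor) * (grid_w // downsample_factor)


feats = [[0.0], [1.0], [2.0]]  # one CLS token followed by two patch tokens
print(len(select_vision_features(feats, "default")))  # 2
print(tokens_after_downsample(26, 26, 2))  # 169
```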
- get_partition_rules(*args, **kwargs)[source]#
Retrieves the combined partition rules from the text and vision configurations.
- Parameters
*args – Positional arguments passed to the underlying config partition rule methods.
**kwargs – Keyword arguments passed to the underlying config partition rule methods.
- Returns
Combined partition rules from both text and vision models.
- Return type
Tuple
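get_partition_rules returns (pattern, spec) pairs. A minimal sketch of how combined text and vision rules could be matched against parameter paths, assuming the combination is a plain concatenation with the text rules placed first and that the first regex found by re.search wins. The rule tables are hypothetical, and plain tuples stand in for jax.sharding.PartitionSpec:

```python
import re

# Hypothetical rule tables: (regex, spec) pairs. Real rules pair a regex with
# a jax.sharding.PartitionSpec; plain tuples stand in for it here.
TEXT_RULES = (
    (r"language_model/.*embed_tokens/embedding", ("tp", "fsdp")),
    (r"language_model/.*(q_proj|k_proj|v_proj)/kernel", ("fsdp", "tp")),
)
VISION_RULES = (
    (r"vision_tower/.*kernel", ("fsdp", "tp")),
    (r".*", None),  # catch-all: replicate anything not matched above
)


def combine_rules(*rule_sets):
    """Concatenate rule tuples; earlier sets take priority when matching."""
    combined = ()
    for rules in rule_sets:
        combined += tuple(rules)
    return combined


def match_spec(param_path, rules):
    """Return the spec of the first rule whose regex matches param_path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


rules = combine_rules(TEXT_RULES, VISION_RULES)
print(match_spec("language_model/layers/0/self_attn/q_proj/kernel", rules))
print(match_spec("multi_modal_projector/linear_1/bias", rules))
```

Because matching stops at the first hit, a catch-all such as ".*" has to come last; in the sketch it sits at the end of the vision table, so it only applies to parameters no earlier rule claims.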
- model_type: str = 'aya_vision'#
- sub_configs: dict[str, 'PretrainedConfig'] = {'text_config': <class 'easydel.__init__.modules.auto.auto_configuration.AutoEasyDeLConfig'>, 'vision_config': <class 'easydel.__init__.modules.auto.auto_configuration.AutoEasyDeLConfig'>}#
- class easydel.__init__.AyaVisionForConditionalGeneration(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
AyaVision model for conditional text generation based on image inputs. Combines a vision tower and a language model with a multi-modal projector.
- config#
Configuration object.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
JAX precision level.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- get_image_features(pixel_values: Union[Array, ndarray, bool, number]) → Union[Array, ndarray, bool, number][source]#
Extracts and projects image features from the vision tower.
- Parameters
pixel_values (chex.Array) – Input pixel values for the images.
- Returns
Processed image features ready for the language model.
- Return type
chex.Array
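A common way to consume the output of get_image_features is to scatter the projected features into the text embeddings at every position where input_ids equals image_token_index (255036 by default in AyaVisionConfig). The sketch below shows that bookkeeping with plain Python lists, assuming exactly one projected feature per image-token position; merge_image_features is a hypothetical helper, not EasyDeL's internal code:

```python
IMAGE_TOKEN_INDEX = 255036  # default image_token_index in AyaVisionConfig


def merge_image_features(input_ids, text_embeds, image_features):
    """Replace the embedding at each image-token position with the next image feature.

    input_ids: list[int]; text_embeds: one vector per token (same length);
    image_features: one vector per image-token position, in order.
    """
    positions = [i for i, tok in enumerate(input_ids) if tok == IMAGE_TOKEN_INDEX]
    if len(positions) != len(image_features):
        raise ValueError(
            f"{len(positions)} image tokens but {len(image_features)} image features"
        )
    merged = list(text_embeds)  # leave the caller's list untouched
    for pos, feat in zip(positions, image_features):
        merged[pos] = feat
    return merged


ids = [5, IMAGE_TOKEN_INDEX, IMAGE_TOKEN_INDEX, 42]
text = [[0.1], [0.0], [0.0], [0.4]]
img = [[9.0], [8.0]]
print(merge_image_features(ids, text, img))  # [[0.1], [9.0], [8.0], [0.4]]
```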
- loss_type = 'ForCausalLM'#
- prepare_inputs_for_generation(input_ids: Union[Array, ndarray, bool, number], max_length: int, pixel_values: Optional[Union[Array, ndarray, bool, number]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None)[source]#
Prepares inputs for text generation, including pixel values if provided.
- Parameters
input_ids (chex.Array) – Initial input token IDs.
max_length (int) – Maximum generation length.
pixel_values (Optional[chex.Array]) – Pixel values for image input.
attention_mask (Optional[chex.Array]) – Attention mask.
- Returns
Model inputs ready for generation.
- Return type
dict
- update_inputs_for_generation(model_outputs, model_kwargs)[source]#
Updates model inputs for the next step of generation, removing pixel values after the first step.
- Parameters
model_outputs – Outputs from the previous generation step.
model_kwargs – Current keyword arguments for the model.
- Returns
Updated model keyword arguments.
- Return type
dict
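The pixel-dropping behaviour of update_inputs_for_generation can be sketched as follows. This is a hypothetical stand-in, assuming the per-step keyword arguments are a plain dict and that the cache is carried under a past_key_values key; the actual EasyDeL field names may differ:

```python
def update_inputs_for_generation_sketch(model_outputs, model_kwargs):
    """Hypothetical stand-in: carry the cache forward and drop pixel_values."""
    updated = dict(model_kwargs)  # do not mutate the caller's dict
    # Assumed key name; the real cache field may be named differently.
    updated["past_key_values"] = model_outputs.get("past_key_values")
    # The image was already encoded during the first step, so later steps
    # only need the cached keys/values, not the raw pixels.
    updated.pop("pixel_values", None)
    return updated


step0_kwargs = {"pixel_values": [[0.5, 0.5]], "attention_mask": [1, 1, 1]}
step1_kwargs = update_inputs_for_generation_sketch({"past_key_values": "cache@1"}, step0_kwargs)
print(sorted(step1_kwargs))  # ['attention_mask', 'past_key_values']
```

Using pop with a default makes later steps safe: once pixel_values is gone, calling the function again simply refreshes the cache.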
- class easydel.__init__.BaseTrainer(arguments: tp.Optional[TrainingArguments] = None, model_state: tp.Optional[EasyDeLState] = None, model: tp.type[EasyDeLBaseModule] = None, dataset_train: tp.Optional[Dataset] = None, dataset_eval: tp.Optional[Dataset] = None, data_collator: tp.Optional[tp.Callable] = None, finetune: bool = True, checkpoint_path: tp.Optional[tp.Union[str, os.PathLike]] = None, **deprecated_kwargs)[source]#
Bases:
BaseTrainerProtocol
- apply_training_hooks(metrics: LossMetrics) → LossMetrics[source]#
Apply training hooks to the model.
- configure_dataloaders() → TrainerConfigureDataloaderOutput[source]#
Configures the dataloaders for training and evaluation.
This method creates the training and evaluation dataloaders using the provided datasets and data collator. It also determines the maximum number of training and evaluation steps based on the dataset sizes and training arguments.
- Returns
An object containing the configured dataloaders and the maximum number of training and evaluation steps.
- Return type
TrainerConfigureDataloaderOutput
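A minimal sketch of the step-count arithmetic, assuming the steps per epoch equal the number of batches (with or without dropping an incomplete last batch) and that max_training_steps acts as an upper bound. estimate_max_steps is hypothetical and ignores gradient accumulation:

```python
def estimate_max_steps(num_examples, batch_size, num_epochs,
                       max_training_steps=None, drop_last=True):
    """Hypothetical step count: batches per epoch times epochs, optionally capped."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if drop_last:
        per_epoch = num_examples // batch_size  # incomplete last batch discarded
    else:
        per_epoch = -(-num_examples // batch_size)  # ceiling division
    total = per_epoch * num_epochs
    if max_training_steps is not None:
        total = min(total, max_training_steps)  # assumed capping behaviour
    return total


print(estimate_max_steps(1001, 8, 10))                          # 1250
print(estimate_max_steps(1001, 8, 10, drop_last=False))         # 1260
print(estimate_max_steps(1001, 8, 10, max_training_steps=500))  # 500
```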
- abstract configure_functions() → TrainerConfigureFunctionOutput[source]#
Configures and JIT-compiles the training and evaluation step functions.
This method sets up the necessary functions for training and evaluation, including initialization of the model state, sharding of the model parameters and optimizer state, and JIT-compilation of the training and evaluation step functions.
- Returns
An object containing the configured functions and other relevant information.
- Return type
TrainerConfigureFunctionOutput
- configure_model() → TrainerConfigureModelOutput[source]#
Configures the model, optimizer, scheduler, and configuration.
This method retrieves the model configuration from the model state, creates the optimizer and scheduler using the training arguments, and returns an object containing the configured model, optimizer, scheduler, and configuration.
- Returns
An object containing the configured model, optimizer, scheduler, and configuration.
- Return type
TrainerConfigureModelOutput
- abstract create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start']) → Callable[source]#
Creates a function to collect and process batches of data for training or evaluation.
This function handles padding or truncating sequences to the specified max_sequence_length based on the chosen truncation_mode.
- Parameters
max_sequence_length (int) – The maximum allowed sequence length.
truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode. Defaults to “keep_end”.
- Returns
A function that takes a batch of data and returns a processed batch.
- Return type
tp.Callable
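The two truncation modes differ only in which end of an over-long sequence is kept. A sketch for a single sequence, assuming right-padding with a hypothetical pad_token_id of 0 and a 0/1 attention mask; truncate_and_pad is illustrative and is not the function returned by create_collect_function:

```python
from typing import List, Literal, Tuple


def truncate_and_pad(
    ids: List[int],
    max_sequence_length: int,
    truncation_mode: Literal["keep_end", "keep_start"] = "keep_end",
    pad_token_id: int = 0,
) -> Tuple[List[int], List[int]]:
    """Fit ids to max_sequence_length; return (input_ids, attention_mask)."""
    if max_sequence_length <= 0:
        raise ValueError("max_sequence_length must be positive")
    if truncation_mode == "keep_end":
        kept = ids[-max_sequence_length:]  # drop tokens from the front
    elif truncation_mode == "keep_start":
        kept = ids[:max_sequence_length]  # drop tokens from the back
    else:
        raise ValueError(f"unknown truncation_mode: {truncation_mode!r}")
    num_pad = max_sequence_length - len(kept)
    # Right-pad; padded positions are masked out with 0.
    return kept + [pad_token_id] * num_pad, [1] * len(kept) + [0] * num_pad


print(truncate_and_pad([1, 2, 3, 4, 5], 3, "keep_end"))    # ([3, 4, 5], [1, 1, 1])
print(truncate_and_pad([1, 2, 3, 4, 5], 3, "keep_start"))  # ([1, 2, 3], [1, 1, 1])
print(truncate_and_pad([1, 2], 4))                         # ([1, 2, 0, 0], [1, 1, 0, 0])
```

"keep_end" suits chat-style data where the most recent turns matter most; "keep_start" preserves the beginning of a document.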
- create_progress_bar(total: int, desc: str = '', disabled: bool = False) → BaseProgressBar[source]#
Create a progress bar of the specified type.
- property evaluation_batch_size#
- get_runstage_flops(is_training) → Union[float, Tuple[float, bool]][source]#
Return the total number of FLOPs for the model.
- initialize_trainer_utils()[source]#
Initializes various utilities used by the trainer.
This includes setting up Weights & Biases, initializing the training timer, configuring dataloaders, configuring the model and optimizer, sharding the model and reference model states, and configuring the training and evaluation functions.
- property is_process_zero#
- log_metrics(metrics: Any, pbar: BaseProgressBar, step: int, mode: str = 'train')[source]#
Log metrics and update progress bar.
- log_weight_distribution(state: EasyDeLState, step: int)[source]#
Log distribution of weights.
- property mesh#
- property model#
- on_step_end(state: EasyDeLState, metrics: Any, step: int) → Tuple[EasyDeLState, Any][source]#
Hook called at the end of each training step.
- on_step_start(state: EasyDeLState, step: int) → EasyDeLState[source]#
Hook called at the start of each training step.
- save_information(output_path: Union[str, Path]) → None[source]#
Save the generated information to a markdown file.
- Parameters
output_path – Path where the markdown file should be saved
- save_pretrained(state: EasyDeLState, save_directory: Optional[str] = None, gather_fns: Optional[Union[Any, Mapping[str, Callable], dict[Callable]]] = None, to_torch: bool = False, easystate_to_huggingface_model_kwargs: Optional[dict] = None, torch_save_pretrained_kwargs: Optional[dict] = None)[source]#
Saves the model state as a checkpoint file or to a PyTorch-compatible directory.
- property training_batch_size#
- class easydel.__init__.CLIPConfig(text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
[CLIPConfig] is the configuration class to store the configuration of a [CLIPModel]. It is used to instantiate a CLIP model according to the specified arguments, defining the text model and vision model configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the CLIP [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
text_config (dict, optional) – Dictionary of configuration options used to initialize [CLIPTextConfig].
vision_config (dict, optional) – Dictionary of configuration options used to initialize [CLIPVisionConfig].
projection_dim (int, optional, defaults to 512) – Dimensionality of text and vision projection layers.
logit_scale_init_value (float, optional, defaults to 2.6592) – The initial value of the logit_scale parameter. Default is used as per the original CLIP implementation.
kwargs (optional) – Dictionary of keyword arguments.
Example:
```python
>>> from transformers import CLIPConfig, CLIPModel

>>> # Initializing a CLIPConfig with openai/clip-vit-base-patch32 style configuration
>>> configuration = CLIPConfig()

>>> # Initializing a CLIPModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
>>> model = CLIPModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config

>>> # We can also initialize a CLIPConfig from a CLIPTextConfig and a CLIPVisionConfig
>>> from transformers import CLIPTextConfig, CLIPVisionConfig

>>> # Initializing a CLIPText and CLIPVision configuration
>>> config_text = CLIPTextConfig()
>>> config_vision = CLIPVisionConfig()

>>> config = CLIPConfig.from_text_vision_configs(config_text, config_vision)
```
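The default logit_scale_init_value is stored in log space. Exponentiating it recovers the initial temperature of 0.07 used by the original CLIP model (a similarity multiplier of about 14.3), as this short standalone check shows:

```python
import math

logit_scale_init_value = 2.6592  # CLIPConfig default

# The scale is kept in log space; the multiplier applied to the
# cosine similarities is its exponential, and the temperature its inverse.
scale = math.exp(logit_scale_init_value)
temperature = 1.0 / scale

print(f"scale = {scale:.2f}")                        # scale = 14.28
print(f"temperature = {temperature:.3f}")            # temperature = 0.070
print(f"log(1 / 0.07) = {math.log(1 / 0.07):.4f}")   # log(1 / 0.07) = 2.6593
```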
- classmethod from_text_vision_configs(text_config: CLIPTextConfig, vision_config: CLIPVisionConfig, **kwargs)[source]#
Instantiate a [CLIPConfig] (or a derived class) from a CLIP text model configuration and a CLIP vision model configuration.
- Parameters
text_config (CLIPTextConfig) – The text model configuration.
vision_config (CLIPVisionConfig) – The vision model configuration.
**kwargs – Additional keyword arguments.
- Returns
An instance of a configuration object
- Return type
[CLIPConfig]
- get_partition_rules(*arg, **kwargs)#
Generic partition rules for CLIP text and vision models.
- Parameters
self – The configuration object (unused but part of method signature).
*arg – Additional positional arguments (unused).
**kwargs – Additional keyword arguments (unused).
- Returns
A tuple of partition rules for model parameters.
- Return type
Tuple
- model_type: str = 'clip'#
- sub_configs: dict[str, 'PretrainedConfig'] = {'text_config': <class 'easydel.__init__.modules.clip.clip_configuration.CLIPTextConfig'>, 'vision_config': <class 'easydel.__init__.modules.clip.clip_configuration.CLIPVisionConfig'>}#
- class easydel.__init__.CLIPForImageClassification(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
CLIP vision model with an image classification head on top (a linear layer on the pooled final hidden state).
- config#
Configuration object.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
JAX precision level.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- class easydel.__init__.CLIPModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- compute_loss(*, labels=None, loss_config=None, loss_kwargs=None, **batch) → Tuple[Any, CLIPOutput][source]#
Computes the loss for the model given a batch of inputs and labels.
This method performs a forward pass using the provided batch arguments, then calculates the loss using the determined loss_function. It handles potential label inference (e.g., using input_ids as labels for Causal LM) and default loss configurations.
- Parameters
labels (tp.Optional[chex.Array], optional) – The target labels. If None and the task is Causal LM, input_ids from the batch might be used. Defaults to None.
loss_config (tp.Optional[LossConfig], optional) – Specific configuration for the loss calculation. If None, defaults might be inferred (e.g., for sequence classification). Defaults to None.
loss_kwargs (tp.Optional[tp.Dict], optional) – Additional keyword arguments to pass directly to the loss function. Defaults to None.
**batch – Keyword arguments representing the input batch (e.g., input_ids, attention_mask).
- Returns
- A tuple containing:
The model’s output (a pytree typically including logits, hidden states, etc.).
A LossMetrics object containing the calculated loss and potentially other metrics.
- Return type
tp.Tuple[tp.Any, LossMetrics]
- Raises
AssertionError – If labels are required for the loss function but are not provided or inferred.
AssertionError – If sequence classification loss is used without num_labels in the config.
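For CLIPModel, the standard CLIP objective is a symmetric cross-entropy over the scaled cosine similarities between every text and every image in the batch, with matching pairs on the diagonal. The sketch below implements that objective in plain Python for illustration; it assumes the logit scale is stored in log space (as in CLIPConfig) and is not EasyDeL's loss code:

```python
import math


def _normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def _log_softmax(row):
    m = max(row)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_sum_exp for x in row]


def clip_contrastive_loss(text_embeds, image_embeds, logit_scale=2.6592):
    """Symmetric cross-entropy; pair i is (text_embeds[i], image_embeds[i])."""
    texts = [_normalize(v) for v in text_embeds]
    images = [_normalize(v) for v in image_embeds]
    scale = math.exp(logit_scale)
    # logits[i][j] = scaled cosine similarity of text i and image j
    logits = [[scale * sum(a * b for a, b in zip(t, im)) for im in images] for t in texts]
    n = len(logits)
    # Text-to-image direction: each row should pick its own image.
    text_loss = -sum(_log_softmax(logits[i])[i] for i in range(n)) / n
    # Image-to-text direction: each column should pick its own text.
    columns = [[logits[i][j] for i in range(n)] for j in range(n)]
    image_loss = -sum(_log_softmax(columns[j])[j] for j in range(n)) / n
    return (text_loss + image_loss) / 2


texts = [[1.0, 0.0], [0.0, 1.0]]
matched = clip_contrastive_loss(texts, [[1.0, 0.0], [0.0, 1.0]])
swapped = clip_contrastive_loss(texts, [[0.0, 1.0], [1.0, 0.0]])
print(f"matched={matched:.4f} swapped={swapped:.4f}")
```

Correctly paired embeddings drive the loss toward zero, while swapped pairs are penalised by roughly the scale factor.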
- class easydel.__init__.CLIPTextConfig(vocab_size=49408, hidden_size=512, intermediate_size=2048, projection_dim=512, num_hidden_layers=12, num_attention_heads=8, max_position_embeddings=77, hidden_act='quick_gelu', layer_norm_eps=1e-05, attention_dropout=0.0, initializer_range=0.02, initializer_factor=1.0, pad_token_id=1, bos_token_id=49406, eos_token_id=49407, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
This is the configuration class to store the configuration of a [CLIPTextModel]. It is used to instantiate a CLIP text encoder according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the text encoder of the CLIP [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 49408) – Vocabulary size of the CLIP text model. Defines the number of different tokens that can be represented by the input_ids passed when calling [CLIPModel].
hidden_size (int, optional, defaults to 512) – Dimensionality of the encoder layers and the pooler layer.
intermediate_size (int, optional, defaults to 2048) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
projection_dim (int, optional, defaults to 512) – Dimensionality of text and vision projection layers.
num_hidden_layers (int, optional, defaults to 12) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 8) – Number of attention heads for each attention layer in the Transformer encoder.
max_position_embeddings (int, optional, defaults to 77) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
hidden_act (str or function, optional, defaults to “quick_gelu”) – The non-linear activation function (function or string) in the encoder and pooler. If string, “gelu”, “relu”, “selu”, “gelu_new” and “quick_gelu” are supported.
layer_norm_eps (float, optional, defaults to 1e-05) – The epsilon used by the layer normalization layers.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
initializer_factor (float, optional, defaults to 1.0) – A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing).
pad_token_id (int, optional, defaults to 1) – Padding token id.
bos_token_id (int, optional, defaults to 49406) – Beginning of stream token id.
eos_token_id (int, optional, defaults to 49407) – End of stream token id.
Example:
```python
>>> from transformers import CLIPTextConfig, CLIPTextModel

>>> # Initializing a CLIPTextConfig with openai/clip-vit-base-patch32 style configuration
>>> configuration = CLIPTextConfig()

>>> # Initializing a CLIPTextModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
>>> model = CLIPTextModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```
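The default hidden_act, “quick_gelu”, is the sigmoid approximation of GELU used by the original CLIP models, x · sigmoid(1.702 · x). A standalone sketch comparing it with the exact GELU:

```python
import math


def quick_gelu(x: float) -> float:
    """Sigmoid approximation of GELU: x * sigmoid(1.702 * x)."""
    return x * (1.0 / (1.0 + math.exp(-1.702 * x)))


def gelu(x: float) -> float:
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(f"{x:+.1f}  quick_gelu={quick_gelu(x):+.4f}  gelu={gelu(x):+.4f}")
```

The two curves stay within a few hundredths of each other; quick_gelu trades that small error for a single sigmoid instead of an error function.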
- base_config_key: str = 'text_config'#
- get_partition_rules(*arg, **kwargs)#
Generic partition rules for CLIP text and vision models.
- Parameters
self – The configuration object (unused but part of method signature).
*arg – Additional positional arguments (unused).
**kwargs – Additional keyword arguments (unused).
- Returns
A tuple of partition rules for model parameters.
- Return type
Tuple
- model_type: str = 'clip_text_model'#
- class easydel.__init__.CLIPTextModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Bare CLIP text model (transformer) outputting raw hidden-states without any specific head on top.
- config#
Configuration object.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
JAX precision level.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- class easydel.__init__.CLIPTextModelWithProjection(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
CLIP text model with a projection layer on top.
- config#
Configuration object.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
JAX precision level.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- class easydel.__init__.CLIPVisionConfig(hidden_size=768, intermediate_size=3072, projection_dim=512, num_hidden_layers=12, num_attention_heads=12, num_channels=3, image_size=224, patch_size=32, hidden_act='quick_gelu', layer_norm_eps=1e-05, attention_dropout=0.0, initializer_range=0.02, initializer_factor=1.0, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
This is the configuration class to store the configuration of a [CLIPVisionModel]. It is used to instantiate a CLIP vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the vision encoder of the CLIP [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
hidden_size (int, optional, defaults to 768) – Dimensionality of the encoder layers and the pooler layer.
intermediate_size (int, optional, defaults to 3072) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
projection_dim (int, optional, defaults to 512) – Dimensionality of text and vision projection layers.
num_hidden_layers (int, optional, defaults to 12) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 12) – Number of attention heads for each attention layer in the Transformer encoder.
num_channels (int, optional, defaults to 3) – The number of input channels.
image_size (int, optional, defaults to 224) – The size (resolution) of each image.
patch_size (int, optional, defaults to 32) – The size (resolution) of each patch.
hidden_act (str or function, optional, defaults to “quick_gelu”) – The non-linear activation function (function or string) in the encoder and pooler. If string, “gelu”, “relu”, “selu”, “gelu_new” and “quick_gelu” are supported.
layer_norm_eps (float, optional, defaults to 1e-05) – The epsilon used by the layer normalization layers.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
initializer_factor (float, optional, defaults to 1.0) – A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing).
Example:
```python
>>> from transformers import CLIPVisionConfig, CLIPVisionModel

>>> # Initializing a CLIPVisionConfig with openai/clip-vit-base-patch32 style configuration
>>> configuration = CLIPVisionConfig()

>>> # Initializing a CLIPVisionModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
>>> model = CLIPVisionModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```
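image_size and patch_size together fix the vision encoder's sequence length: the image is cut into (image_size / patch_size)² patches, and CLIP prepends one class token. With the defaults that is (224 / 32)² = 49 patches plus one class token. A short sketch of that arithmetic (clip_vision_sequence_length is a hypothetical helper):

```python
def clip_vision_sequence_length(image_size=224, patch_size=32, add_cls_token=True):
    """Number of tokens the vision transformer sees for one square image."""
    if image_size % patch_size:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    # CLIP prepends a single learned class token to the patch sequence.
    return num_patches + (1 if add_cls_token else 0)


print(clip_vision_sequence_length())         # 50  (7 x 7 patches + class token)
print(clip_vision_sequence_length(224, 14))  # 257 (16 x 16 patches + class token)
```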
- base_config_key: str = 'vision_config'#
- get_partition_rules(*arg, **kwargs)#
Generic partition rules for CLIP text and vision models.
- Parameters
self – The configuration object (unused but part of method signature).
*arg – Additional positional arguments (unused).
**kwargs – Additional keyword arguments (unused).
- Returns
A tuple of partition rules for model parameters.
- Return type
Tuple
- model_type: str = 'clip_vision_model'#
- class easydel.__init__.CLIPVisionModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Bare CLIP vision model (transformer) outputting raw hidden-states without any specific head on top.
- config#
Configuration object.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
JAX precision level.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- class easydel.__init__.Cohere2Config(vocab_size=256000, hidden_size=8192, intermediate_size=22528, logit_scale=0.0625, num_hidden_layers=40, num_attention_heads=64, num_key_value_heads=None, hidden_act='silu', max_position_embeddings=8192, initializer_range=0.02, layer_norm_eps=1e-05, use_cache=True, pad_token_id=0, bos_token_id=5, eos_token_id=255001, tie_word_embeddings=True, rope_theta=10000.0, rope_scaling=None, attention_bias=False, attention_dropout=0.0, sliding_window=4096, sliding_window_pattern=4, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 256000) – Vocabulary size of the Cohere model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.
hidden_size (int, optional, defaults to 8192) – Dimensionality of the encoder layers and the pooler layer.
intermediate_size (int, optional, defaults to 22528) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
logit_scale (float, optional, defaults to 0.0625) – The scaling factor applied to the output logits.
num_hidden_layers (int, optional, defaults to 40) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 64) – Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer encoder. Will default to num_attention_heads if not set.
hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.
max_position_embeddings (int, optional, defaults to 8192) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
pad_token_id (int, optional, defaults to 0) – The index of the padding token in the vocabulary.
bos_token_id (int, optional, defaults to 5) – The index of the beginning of sequence token in the vocabulary.
eos_token_id (int, optional, defaults to 255001) – The index of the end of sequence token in the vocabulary.
tie_word_embeddings (bool, optional, defaults to True) – Whether to tie the weights of the input embeddings and the output embeddings.
rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.
attention_bias (bool, optional, defaults to False) – Whether to use attention bias.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
use_qk_norm (bool, optional, defaults to False) – Whether to use query and key normalization.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.
bits (int, optional) – The number of bits to quantize the model to.
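The signature also exposes sliding_window and sliding_window_pattern, which are not described above. Assuming the Hugging Face Cohere2 convention, in which every sliding_window_pattern-th layer uses global attention and the remaining layers attend within a window of sliding_window tokens, the per-layer layout for the defaults can be sketched as follows (layer_attention_windows is hypothetical):

```python
def layer_attention_windows(num_hidden_layers=40, sliding_window_pattern=4, sliding_window=4096):
    """Per-layer attention window; None marks a global-attention layer (assumed convention)."""
    windows = []
    for layer_idx in range(num_hidden_layers):
        # Every sliding_window_pattern-th layer (counting from 1) attends globally.
        is_global = (layer_idx + 1) % sliding_window_pattern == 0
        windows.append(None if is_global else sliding_window)
    return windows


windows = layer_attention_windows()
print(windows[:4])                      # [4096, 4096, 4096, None]
print(sum(w is None for w in windows))  # 10
```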
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- static get_weight_decay_exclusions()[source]#
Returns a tuple of parameter names for which weight decay should be excluded.
- property granted_freq_max_position_embedding: int#
Returns the maximum position embedding size for frequency-based position embeddings, falling back to max_position_embeddings.
- property granted_mask_max_position_embedding: int#
Returns the maximum position embedding size for mask-based position embeddings, falling back to max_position_embeddings.
- model_type: str = 'cohere'#
- class easydel.__init__.Cohere2ForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.Cohere2ForSequenceClassification(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.Cohere2Model(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.CohereConfig(vocab_size=256000, hidden_size=8192, intermediate_size=22528, logit_scale=0.0625, num_hidden_layers=40, num_attention_heads=64, num_key_value_heads=None, hidden_act='silu', max_position_embeddings=8192, initializer_range=0.02, layer_norm_eps=1e-05, use_cache=True, pad_token_id=0, bos_token_id=5, eos_token_id=255001, tie_word_embeddings=True, rope_theta=10000.0, attention_bias=False, attention_dropout=0.0, use_qk_norm: bool = False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 256000) – Vocabulary size of the Cohere model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.
hidden_size (int, optional, defaults to 8192) – Dimensionality of the encoder layers and the pooler layer.
intermediate_size (int, optional, defaults to 22528) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
logit_scale (float, optional, defaults to 0.0625) – The scaling factor applied to the output logits.
num_hidden_layers (int, optional, defaults to 40) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 64) – Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer encoder. Will default to num_attention_heads if not set.
hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.
max_position_embeddings (int, optional, defaults to 8192) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
pad_token_id (int, optional, defaults to 0) – The index of the padding token in the vocabulary.
bos_token_id (int, optional, defaults to 5) – The index of the beginning of sequence token in the vocabulary.
eos_token_id (int, optional, defaults to 255001) – The index of the end of sequence token in the vocabulary.
tie_word_embeddings (bool, optional, defaults to True) – Whether to tie the weights of the input embeddings and the output embeddings.
rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.
attention_bias (bool, optional, defaults to False) – Whether to use attention bias.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
use_qk_norm (bool, optional, defaults to False) – Whether to use query and key normalization.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.
bits (int, optional) – The number of bits to quantize the model to.
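Assuming, as in the Hugging Face Cohere implementation, that logit_scale multiplies the language-model head's output logits before the softmax, the default of 0.0625 (1/16) shrinks every logit gap by a factor of 16, which makes the softmax less peaked than it would be on the raw values. A small illustration with made-up logits:

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


logit_scale = 0.0625  # CohereConfig default
raw_logits = [32.0, 16.0, 0.0]  # made-up LM-head outputs for three tokens
scaled_logits = [x * logit_scale for x in raw_logits]  # [2.0, 1.0, 0.0]

print([round(p, 3) for p in softmax(raw_logits)])
print([round(p, 3) for p in softmax(scaled_logits)])
```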
- attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#
Attaches the following EasyDeL-specific custom arguments to the configuration.
- Parameters
self – The current configuration object.
gradient_checkpointing (EasyDeLGradientCheckPointers) – The gradient checkpointing strategy, which controls how much activation memory JAX keeps during the backward pass in exchange for recomputation.
bits (Optional[int]) – The number of bits used for quantization.
**kwargs – Additional keyword arguments.
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- static get_weight_decay_exclusions()[source]#
Returns a tuple of parameter names for which weight decay should be excluded.
- property granted_freq_max_position_embedding: int#
Returns the maximum position embedding size for frequency-based position embeddings, falling back to max_position_embeddings.
- property granted_mask_max_position_embedding: int#
Returns the maximum position embedding size for mask-based position embeddings, falling back to max_position_embeddings.
- model_type: str = 'cohere'#
- class easydel.__init__.CohereForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.CohereForSequenceClassification(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Cohere model for sequence classification.
- config#
Configuration object (must include num_labels).
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
JAX precision level.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- class easydel.__init__.CohereModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.ConfigType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#
Bases:
str, Enum
Enumeration defining types of configurations that can be registered.
- MODULE_CONFIG#
Represents standard module configuration classes.
- MODULE_CONFIG = 'module-config'#
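Because ConfigType mixes in str, its members behave like their string values. The sketch re-declares a stand-in with only the documented MODULE_CONFIG member instead of importing easydel, so it runs on its own.

```python
from enum import Enum


class ConfigType(str, Enum):
    """Stand-in mirroring the documented enum; only MODULE_CONFIG is listed."""

    MODULE_CONFIG = "module-config"


# Lookup by value works, e.g. for a string read from a file or registry key.
kind = ConfigType("module-config")

# Because the enum also subclasses str, members compare equal to their values.
print(kind is ConfigType.MODULE_CONFIG, kind == "module-config")
```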
- class easydel.__init__.DPOConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: tp.Optional[str] = None, clip_grad: tp.Optional[float] = None, custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = None, dataloader_num_workers: tp.Optional[int] = 0, dataloader_pin_memory: tp.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: tp.Optional[int] = None, evaluation_steps: tp.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: tp.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: tp.Optional[dict] = None, learning_rate: float = 1e-06, learning_rate_end: tp.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: tp.Optional[LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: tp.Optional[int] = None, max_sequence_length: tp.Optional[int] = 4096, max_training_steps: tp.Optional[int] = None, model_name: str = 'DPOTrainer', model_parameters: tp.Optional[dict] = None, metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: tp.Optional[int] = None, save_total_limit: tp.Optional[int] = None, scheduler: AVAILABLE_SCHEDULERS = EasyDeLSchedulers.NONE, sparsify_module: bool 
= False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: tp.Optional[dict] = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: tp.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: tp.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: tp.Optional[jnp.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: tp.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0, beta: float = 0.1, label_smoothing: float = 0.0, loss_type: ~typing.Literal['sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down'] = 'sigmoid', use_weighting: bool = False, label_pad_token_id: int = -100, padding_value: ~typing.Optional[int] = None, max_length: ~typing.Optional[int] = 512, max_prompt_length: ~typing.Optional[int] = 256, max_completion_length: ~typing.Optional[int] = None, is_encoder_decoder: ~typing.Optional[bool] = None, disable_dropout: bool = True, precompute_ref_log_probs: bool = False, dataset_num_proc: ~typing.Optional[int] = None, reference_free: bool = False, force_use_ref_model: bool = False, sync_ref_model: bool = False, ref_model_mixup_alpha: float = 0.9, ref_model_sync_steps: int = 64, rpo_alpha: ~typing.Optional[float] = None, tools: ~typing.Optional[~typing.List[~typing.Union[dict, ~typing.Callable]]] = None)[source]#
Bases:
TrainingArguments
Configuration class for Direct Preference Optimization (DPO) training.
Inherits from TrainingArguments and adds parameters specific to DPO training as described in https://arxiv.org/abs/2305.18290. This configuration controls various aspects of the DPO training process including loss computation, model architecture, and dataset processing.
- beta#
Temperature parameter (β) controlling deviation from the reference model. Higher values penalize deviation more strongly, keeping the policy closer to the reference model. Default: 0.1
- Type
float
- label_smoothing#
Smoothing factor for labels in loss calculation. Helps prevent overconfidence. 0.0 means no smoothing. Default: 0.0
- Type
float
- loss_type#
Type of contrastive loss function to use. Valid options: ‘sigmoid’, ‘hinge’, ‘ipo’, ‘exo_pair’, ‘nca_pair’, ‘robust’, ‘bco_pair’, ‘sppo_hard’, ‘aot’, ‘aot_pair’, ‘apo_zero’, ‘apo_down’. Default: ‘sigmoid’
- Type
LOSS_FN_VARIENTS
- use_weighting#
Whether to apply example weighting in loss calculation. Default: False
- Type
bool
- label_pad_token_id#
Token ID used for padding labels. Default: -100
- Type
int
- padding_value#
Value used for padding sequences. If None, uses model’s default padding token. Default: None
- Type
int | None
- max_length#
Maximum total sequence length (prompt + completion). Default: 512
- Type
int | None
- max_prompt_length#
Maximum length for prompt sequences. Default: 256
- Type
int | None
- max_completion_length#
Maximum length for completion sequences. Auto-calculated as max_length - max_prompt_length if None. Default: None
- Type
int | None
- is_encoder_decoder#
Explicitly set if model is encoder-decoder. Auto-detected if None. Default: None
- Type
bool | None
- disable_dropout#
Whether to disable dropout during training for deterministic behavior. Default: True
- Type
bool
- precompute_ref_log_probs#
Whether to precompute reference model log probabilities before training. Default: False
- Type
bool
- dataset_num_proc#
Number of processes for dataset preprocessing. Default: None (sequential processing)
- Type
int | None
- reference_free#
Whether to use the reference-free variant of DPO. Default: False
- Type
bool
- force_use_ref_model#
Force the use of a reference model even when reference_free=True. Default: False
- Type
bool
- sync_ref_model#
Whether to periodically sync the reference model with the model being trained. Default: False
- Type
bool
- learning_rate#
Optimizer learning rate. Default: 1e-6
- Type
float
- ref_model_mixup_alpha#
Alpha parameter for mixup between policy and reference models. Default: 0.9
- Type
float
- ref_model_sync_steps#
Number of steps between reference model syncs. Default: 64
- Type
int
- rpo_alpha#
Alpha parameter for Relative Preference Optimization. None disables RPO. Default: None
- Type
float | None
- tools#
Additional tools for the training process.
- Type
list[dict | Callable] | None
Example
>>> from easydel import DPOConfig
>>> config = DPOConfig(
...     beta=0.2, loss_type="ipo", max_length=1024, learning_rate=5e-6
... )
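All loss_type variants score the same quantity: how much more the policy prefers the chosen response over the rejected one than the reference model does. A minimal sketch of three variants for a single example, assuming the formulas used by TRL's DPOTrainer ('sigmoid' with label_smoothing, 'hinge', 'ipo'); EasyDeL's implementation may differ in details such as batching or length normalization. Inputs are summed sequence log-probabilities, and dpo_loss is a hypothetical helper.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dpo_loss(
    policy_chosen: float,
    policy_rejected: float,
    ref_chosen: float,
    ref_rejected: float,
    beta: float = 0.1,
    label_smoothing: float = 0.0,
    loss_type: str = "sigmoid",
) -> float:
    """Per-example DPO loss from summed sequence log-probabilities."""
    # Policy preference margin minus the reference model's preference margin.
    logits = (policy_chosen - policy_rejected) - (ref_chosen - ref_rejected)
    if loss_type == "sigmoid":
        # label_smoothing mixes in the loss for the flipped preference.
        return (
            -(1.0 - label_smoothing) * log_sigmoid(beta * logits)
            - label_smoothing * log_sigmoid(-beta * logits)
        )
    if loss_type == "hinge":
        return max(0.0, 1.0 - beta * logits)
    if loss_type == "ipo":
        # IPO regresses the margin towards 1 / (2 * beta).
        return (logits - 1.0 / (2.0 * beta)) ** 2
    raise ValueError(f"loss_type {loss_type!r} is not covered by this sketch")


print(dpo_loss(-12.0, -15.0, -13.0, -14.0, beta=0.1))
```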
- beta: float = 0.1#
- dataset_num_proc: Optional[int] = None#
- disable_dropout: bool = True#
- force_use_ref_model: bool = False#
- classmethod from_dict(data)#
Create an instance from a dictionary (deserialization).
- classmethod from_json(json_str)#
Create an instance from a JSON string.
- is_encoder_decoder: Optional[bool] = None#
- label_pad_token_id: int = -100#
- label_smoothing: float = 0.0#
- learning_rate: float = 1e-06#
- loss_type: Literal['sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down'] = 'sigmoid'#
- max_completion_length: Optional[int] = None#
- max_length: Optional[int] = 512#
- max_prompt_length: Optional[int] = 256#
- model_name: str = 'DPOTrainer'#
- padding_value: Optional[int] = None#
- precompute_ref_log_probs: bool = False#
- ref_model_mixup_alpha: float = 0.9#
- ref_model_sync_steps: int = 64#
- reference_free: bool = False#
- replace(**kwargs)#
- rpo_alpha: Optional[float] = None#
- sync_ref_model: bool = False#
- to_dict()#
Convert the instance to a dictionary for JSON serialization.
- to_json(**kwargs)#
Convert the instance to a JSON string.
- tools: Optional[List[Union[dict, Callable]]] = None#
- use_weighting: bool = False#
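When sync_ref_model is enabled, ref_model_mixup_alpha and ref_model_sync_steps govern how the reference model follows the policy. A sketch assuming the TR-DPO soft update that TRL documents for these same parameter names: every ref_model_sync_steps steps, reference ← α · policy + (1 − α) · reference. Parameters are plain dicts of lists standing in for JAX pytrees (where jax.tree_util.tree_map would perform the same element-wise mix); maybe_sync_reference is a hypothetical helper, not an EasyDeL API.

```python
from typing import Dict, List

Params = Dict[str, List[float]]


def mixup_reference(policy: Params, reference: Params, alpha: float) -> Params:
    """Soft update applied per parameter: ref <- alpha * policy + (1 - alpha) * ref."""
    return {
        name: [alpha * p + (1.0 - alpha) * r for p, r in zip(policy[name], ref_values)]
        for name, ref_values in reference.items()
    }


def maybe_sync_reference(
    step: int,
    policy: Params,
    reference: Params,
    sync_ref_model: bool = False,
    ref_model_sync_steps: int = 64,
    ref_model_mixup_alpha: float = 0.9,
) -> Params:
    """Return the reference parameters to use after `step` (hypothetical helper)."""
    # Only mix on steps that are positive multiples of ref_model_sync_steps.
    if sync_ref_model and step > 0 and step % ref_model_sync_steps == 0:
        return mixup_reference(policy, reference, ref_model_mixup_alpha)
    return reference


policy = {"w": [1.0, 2.0]}
reference = {"w": [0.0, 0.0]}
print(maybe_sync_reference(64, policy, reference, sync_ref_model=True))
```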
- class easydel.__init__.DPOTrainer(arguments: DPOConfig, model: Union[EasyDeLBaseModule, EasyDeLState], reference_model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, processing_class: Optional[Any] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Any] = None, data_collator: Optional[Callable] = None)[source]#
Bases:
Trainer
Trainer for Direct Preference Optimization (DPO).
This trainer handles the training, evaluation, and checkpointing of language models using the DPO algorithm. It supports sharding, gradient accumulation, mixed precision training, LoRA, and precomputed reference model log probabilities.
- compute_reference_log_probs(state: EasyDeLState, padded_batch: Dict) tuple[Any, Any][source]#
Computes the reference model's log probabilities for a single padded batch of a DPO-specific dataset.
- Parameters
state (EasyDeLState) – The EasyDeLState object of the model (used if no reference model is provided).
padded_batch (tp.Dict) – The padded batch of data.
- Returns
A tuple containing the log probabilities for the chosen and rejected responses.
- Return type
tuple[tp.Any, tp.Any]
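compute_reference_log_probs returns one value per response. A common way to obtain it, and the one assumed here, is to sum the log-softmax probability of each completion token, skipping positions labelled with label_pad_token_id (-100 by default) such as prompt and padding tokens. The pure-Python sketch below handles one already-shifted sequence; sequence_log_prob is a hypothetical name.

```python
import math
from typing import Sequence


def sequence_log_prob(
    logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    label_pad_token_id: int = -100,
) -> float:
    """Sum log p(label_t) over positions whose label is not padding.

    logits[t] holds the vocabulary scores that predict labels[t]
    (the usual one-position shift is assumed to be applied already).
    """
    total = 0.0
    for scores, label in zip(logits, labels):
        if label == label_pad_token_id:
            continue  # prompt and padding positions do not contribute
        m = max(scores)
        log_norm = m + math.log(sum(math.exp(s - m) for s in scores))
        total += scores[label] - log_norm
    return total


# Toy 3-token vocabulary; the first position belongs to the prompt and is masked.
chosen_logits = [[2.0, 0.5, 0.1], [0.2, 1.5, 0.3], [0.1, 0.2, 3.0]]
rejected_logits = [[2.0, 0.5, 0.1], [1.0, 0.2, 0.3], [0.4, 0.2, 0.1]]
chosen_logps = sequence_log_prob(chosen_logits, [-100, 1, 2])
rejected_logps = sequence_log_prob(rejected_logits, [-100, 2, 1])
print(chosen_logps, rejected_logps)
```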
- configure_dataloaders()[source]#
Returns the training dataloader, potentially with precomputed reference log probabilities.
If precompute_ref_log_probs is enabled, this method computes the reference model’s log probabilities for the chosen and rejected responses in the training dataset and adds them as columns to the dataset.
- Returns
The training dataloader.
- Return type
tensorflow.data.Dataset
- configure_functions() TrainerConfigureFunctionOutput[source]#
Configures and JIT-compiles the training and evaluation step functions.
- This method sets up the necessary functions for training and evaluation, including:
Initialization of the model state.
Sharding of the model parameters and optimizer state.
JIT-compilation of the training and evaluation step functions.
- Returns
An object containing the configured functions and other relevant information.
- Return type
TrainerConfigureFunctionOutput
- create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') Callable[source]#
Creates a data collection function for batching.
For DPO training, this method simply returns the pre-configured data_collator.
- Parameters
max_sequence_length (int) – The maximum sequence length (not used in this implementation).
truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode (not used in this implementation). Defaults to “keep_end”.
- Returns
The data collator function.
- Return type
tp.Callable
- on_step_end(state: EasyDeLState, metrics: Any, step: int) Tuple[EasyDeLState, Any][source]#
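Hook called at the end of each training step. It receives the current state, metrics, and step number and returns the (possibly updated) state and metrics.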
- static process_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens)[source]#
- static tokenize_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens)[source]#
Tokenize a row of the dataset.
- Parameters
features (dict[str, str]) – Row of the dataset, should contain the keys “prompt”, “chosen”, and “rejected”.
processing_class (PreTrainedTokenizerBase) – Processing class used to process the data.
max_prompt_length (int or None) – Maximum length of the prompt sequence. If None, the prompt sequence is not truncated.
max_completion_length (int or None) – Maximum length of the completion sequences. If None, the completion sequences are not truncated.
add_special_tokens (bool) – Whether to add special tokens to the sequences. Typically used for encoder-decoder models. If True, the prompt sequence will have a bos token prepended and an eos token appended. In any case, the completion sequences will have an eos token appended.
- Returns
Tokenized sequences with the keys “prompt_input_ids”, “chosen_input_ids”, and “rejected_input_ids”.
- Return type
dict[str, list[int]]
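The truncation and special-token rules described for tokenize_row can be sketched as follows. The tokenizer is replaced by a hypothetical character-level encode function, and the prompt is assumed to be truncated from the left (keeping its end) while completions are truncated from the right, as in TRL's DPOTrainer.tokenize_row; the truncation side EasyDeL uses is not stated above.

```python
from typing import Callable, Dict, List, Optional


def tokenize_row_sketch(
    features: Dict[str, str],
    encode: Callable[[str], List[int]],
    bos_token_id: int,
    eos_token_id: int,
    max_prompt_length: Optional[int],
    max_completion_length: Optional[int],
    add_special_tokens: bool,
) -> Dict[str, List[int]]:
    """Hypothetical re-implementation of the documented tokenize_row behaviour."""
    prompt = encode(features["prompt"])
    chosen = encode(features["chosen"])
    rejected = encode(features["rejected"])
    if add_special_tokens:
        # Encoder-decoder style: wrap the prompt with bos/eos.
        prompt = [bos_token_id] + prompt + [eos_token_id]
    # Completions always end with eos.
    chosen = chosen + [eos_token_id]
    rejected = rejected + [eos_token_id]
    if max_prompt_length is not None:
        prompt = prompt[-max_prompt_length:]  # keep the end of the prompt
    if max_completion_length is not None:
        chosen = chosen[:max_completion_length]
        rejected = rejected[:max_completion_length]
    return {
        "prompt_input_ids": prompt,
        "chosen_input_ids": chosen,
        "rejected_input_ids": rejected,
    }


def encode(text: str) -> List[int]:
    """Toy character-level 'tokenizer' (hypothetical): one id per character."""
    return [ord(ch) for ch in text]


row = {"prompt": "abcd", "chosen": "xy", "rejected": "z"}
out = tokenize_row_sketch(row, encode, 1, 2, 3, 2, add_special_tokens=False)
print(out)
```

Because eos is appended before truncation in this sketch, a completion cut at max_completion_length can lose its final eos token.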
- class easydel.__init__.DbrxAttentionConfig(attn_pdrop: float = 0, clip_qkv: Optional[float] = 8, kv_n_heads: int = 1, rope_theta: float = 10000.0, **kwargs: Any)[source]#
Bases:
EasyDeLBaseConfig
This is the configuration class that stores the attention-related configuration of a [DbrxModel].
- Parameters
attn_pdrop (float, optional, defaults to 0.0) – The dropout probability applied to the attention output.
clip_qkv (float, optional, defaults to 8.0) – The clip value applied to the query, key, and value tensors.
kv_n_heads (int, optional, defaults to 1) – The number of attention heads for the key and value tensors.
rope_theta (float, optional, defaults to 10000.0) – The theta value for the rotary position embedding.
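Two of these settings are easy to picture in isolation. The sketch assumes clip_qkv clamps every entry of the query/key/value projection output to [-clip_qkv, clip_qkv], as the Hugging Face DBRX attention does, and that kv_n_heads groups query heads as in grouped-query attention. With the defaults (kv_n_heads=1 and DbrxConfig's n_heads=16), all 16 query heads share a single key/value head. Both helpers are hypothetical.

```python
from typing import List, Optional


def clip_values(values: List[float], clip_qkv: Optional[float]) -> List[float]:
    """Clamp each projected q/k/v entry to [-clip_qkv, clip_qkv]; None disables it."""
    if clip_qkv is None:
        return list(values)
    return [min(max(v, -clip_qkv), clip_qkv) for v in values]


def kv_head_for_query_head(query_head: int, n_heads: int, kv_n_heads: int) -> int:
    """Grouped-query attention: consecutive query heads share one key/value head."""
    if n_heads % kv_n_heads != 0:
        raise ValueError("n_heads must be divisible by kv_n_heads")
    group_size = n_heads // kv_n_heads
    return query_head // group_size


print(clip_values([-10.0, 0.5, 9.0], clip_qkv=8.0))
print([kv_head_for_query_head(h, n_heads=16, kv_n_heads=4) for h in range(16)])
```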
- classmethod from_pretrained(pretrained_model_name_or_path: str, **kwargs: Any) PretrainedConfig[source]#
Instantiate a [PretrainedConfig] (or a derived class) from a pretrained model configuration.
- Parameters
pretrained_model_name_or_path (str or os.PathLike) –
This can be either:
a string, the model id of a pretrained model configuration hosted inside a model repo on huggingface.co.
a path to a directory containing a configuration file saved using the [~PretrainedConfig.save_pretrained] method, e.g., ./my_model_directory/.
a path or url to a saved configuration JSON file, e.g., ./my_model_directory/configuration.json.
cache_dir (str or os.PathLike, optional) – Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
force_download (bool, optional, defaults to False) – Whether or not to force to (re-)download the configuration files and override the cached versions if they exist.
resume_download – Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
proxies (Dict[str, str], optional) – A dictionary of proxy servers to use by protocol or endpoint, e.g., {‘http’: ‘foo.bar:3128’, ‘http://hostname’: ‘foo.bar:4012’}. The proxies are used on each request.
token (str or bool, optional) – The token to use as HTTP bearer authorization for remote files. If True, or not specified, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
revision (str, optional, defaults to “main”) –
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
<Tip>
To test a pull request you made on the Hub, you can pass revision=“refs/pr/<pr_number>”.
</Tip>
return_unused_kwargs (bool, optional, defaults to False) –
If False, then this function returns just the final configuration object.
If True, then this function returns a tp.Tuple(config, unused_kwargs) where unused_kwargs is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the part of kwargs which has not been used to update config and is otherwise ignored.
subfolder (str, optional, defaults to “”) – In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here.
kwargs (Dict[str, tp.Any], optional) – The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not configuration attributes is controlled by the return_unused_kwargs keyword parameter.
- Returns
The configuration object instantiated from this pretrained model.
- Return type
[PretrainedConfig]
Examples:
>>> from transformers import BertConfig
>>> # We can't directly instantiate the base class *PretrainedConfig*, so the examples use a
>>> # derived class: BertConfig
>>> config = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased"
... )  # Download configuration from huggingface.co and cache.
>>> config = BertConfig.from_pretrained(
...     "./test/saved_model/"
... )  # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
>>> config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
>>> config = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased", output_attentions=True, foo=False
... )
>>> assert config.output_attentions == True
>>> config, unused_kwargs = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased",
...     output_attentions=True,
...     foo=False,
...     return_unused_kwargs=True,
... )
>>> assert config.output_attentions == True
>>> assert unused_kwargs == {"foo": False}
- class easydel.__init__.DbrxConfig(d_model: int = 2048, n_heads: int = 16, n_layers: int = 24, max_seq_len: int = 2048, vocab_size: int = 32000, resid_pdrop: float = 0.0, emb_pdrop: float = 0.0, attn_config: Optional[DbrxAttentionConfig] = None, ffn_config: Optional[DbrxFFNConfig] = None, use_cache: bool = True, initializer_range: float = 0.02, output_router_logits: bool = False, router_aux_loss_coef: float = 0.05, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs: Any)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
d_model (int, optional, defaults to 2048) – Dimensionality of the encoder layers and the pooler layer.
n_heads (int, optional, defaults to 16) – Number of attention heads for each attention layer in the Transformer encoder.
n_layers (int, optional, defaults to 24) – Number of hidden layers in the Transformer encoder.
max_seq_len (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).
vocab_size (int, optional, defaults to 32000) – Vocabulary size of the DBRX model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.
resid_pdrop (float, optional, defaults to 0.0) – The dropout probability applied to the attention output before it is combined with the residual stream.
emb_pdrop (float, optional, defaults to 0.0) – The dropout probability for the embedding layer.
attn_config ([DbrxAttentionConfig], optional) – The configuration of the attention layer.
ffn_config ([DbrxFFNConfig], optional) – The configuration of the feed forward layer.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
output_router_logits (bool, optional, defaults to False) – Whether or not to output the router logits.
router_aux_loss_coef (float, optional, defaults to 0.05) – The coefficient of the router auxiliary loss.
- attribute_map: dict[str, str] = {'hidden_size': 'd_model', 'max_position_embeddings': 'max_seq_len', 'num_attention_heads': 'n_heads', 'num_hidden_layers': 'n_layers'}#
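attribute_map lets generic code use canonical names (hidden_size, num_hidden_layers, and so on) on a config that stores DBRX-specific ones (d_model, n_layers, and so on). Hugging Face's PretrainedConfig implements this by redirecting attribute reads and writes through the map; the class below is a simplified, self-contained illustration of that idea, not EasyDeL's or Hugging Face's code.

```python
class AliasedConfig:
    """Minimal illustration of attribute_map-style aliasing (hypothetical class)."""

    attribute_map = {
        "hidden_size": "d_model",
        "max_position_embeddings": "max_seq_len",
        "num_attention_heads": "n_heads",
        "num_hidden_layers": "n_layers",
    }

    def __init__(self, d_model=2048, n_heads=16, n_layers=24, max_seq_len=2048):
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len

    def __getattr__(self, name):
        # Only called when normal lookup fails, so canonical names resolve here.
        mapping = type(self).attribute_map
        if name in mapping:
            return getattr(self, mapping[name])
        raise AttributeError(name)

    def __setattr__(self, name, value):
        # Writes through an alias land on the DBRX-specific attribute.
        name = type(self).attribute_map.get(name, name)
        super().__setattr__(name, value)


cfg = AliasedConfig()
cfg.num_hidden_layers = 40
print(cfg.hidden_size, cfg.n_layers)
```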
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model parameters.
These rules define how parameters should be sharded across devices when using model parallelism.
- Parameters
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns
A tuple of partition rules for different parameter patterns.
- Return type
Tuple
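Partition rules are (regex, PartitionSpec) pairs matched against parameter paths. A sketch assuming first-match semantics with re.search; the rules and parameter paths are hypothetical (EasyDeL's actual DBRX rules are not reproduced), and plain tuples stand in for jax.sharding.PartitionSpec so the example runs without JAX. An empty tuple plays the role of PartitionSpec(), i.e. fully replicated.

```python
import re
from typing import Optional, Sequence, Tuple

# Stand-in for jax.sharding.PartitionSpec: one mesh-axis name (or None) per dimension.
Spec = Tuple[Optional[str], ...]

# Hypothetical rules, checked in order.
RULES: Sequence[Tuple[str, Spec]] = (
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"(q_proj|k_proj|v_proj)/kernel", ("fsdp", "tp")),
    (r"o_proj/kernel", ("tp", "fsdp")),
    (r".*", ()),  # catch-all: replicate everything else
)


def match_partition_spec(
    param_path: str, rules: Sequence[Tuple[str, Spec]] = RULES
) -> Spec:
    """Return the spec of the first rule whose regex matches the parameter path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


print(match_partition_spec("model/layers/3/attn/q_proj/kernel"))
```

Putting a catch-all rule last guarantees every parameter receives some sharding.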
- property granted_freq_max_position_embedding: int#
Returns the maximum position embedding size for frequency-based position embeddings.
- Returns
The maximum position embedding size, falling back to max_seq_len if not explicitly set.
- Return type
int
- property granted_mask_max_position_embedding: int#
Returns the maximum position embedding size for mask-based position embeddings.
- Returns
The maximum position embedding size, falling back to max_seq_len if not explicitly set.
- Return type
int
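Both granted_* properties fall back to max_seq_len when no explicit value is set. A sketch of that fallback, assuming the explicit overrides live in attributes named freq_max_position_embeddings and mask_max_position_embeddings (hypothetical names chosen for illustration). The check uses `is not None` so an explicitly set value is never silently replaced.

```python
from typing import Optional


class PositionLimits:
    """Hypothetical sketch of the granted_* fallback (attribute names assumed)."""

    def __init__(
        self,
        max_seq_len: int = 2048,
        freq_max_position_embeddings: Optional[int] = None,
        mask_max_position_embeddings: Optional[int] = None,
    ):
        self.max_seq_len = max_seq_len
        self.freq_max_position_embeddings = freq_max_position_embeddings
        self.mask_max_position_embeddings = mask_max_position_embeddings

    @property
    def granted_freq_max_position_embedding(self) -> int:
        # Use the explicit override when set, otherwise the model's max length.
        if self.freq_max_position_embeddings is not None:
            return self.freq_max_position_embeddings
        return self.max_seq_len

    @property
    def granted_mask_max_position_embedding(self) -> int:
        if self.mask_max_position_embeddings is not None:
            return self.mask_max_position_embeddings
        return self.max_seq_len


limits = PositionLimits(freq_max_position_embeddings=8192)
print(limits.granted_freq_max_position_embedding, limits.granted_mask_max_position_embedding)
```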
- model_type: str = 'dbrx'#
- class easydel.__init__.DbrxFFNConfig(ffn_act_fn: Optional[dict] = None, ffn_hidden_size: int = 3584, moe_num_experts: int = 4, moe_top_k: int = 1, moe_jitter_eps: Optional[float] = None, moe_loss_weight: float = 0.01, moe_normalize_expert_weights: Optional[float] = 1, uniform_expert_assignment: bool = False, **kwargs: Any)[source]#
Bases:
EasyDeLBaseConfig
This is the configuration class that stores the feed-forward-related configuration of a [DbrxModel].
- Parameters
ffn_act_fn (dict, optional) – The activation function configuration for the feed-forward network.
ffn_hidden_size (int, optional, defaults to 3584) – The hidden size of the feed-forward network.
moe_num_experts (int, optional, defaults to 4) – The number of experts in the Mixture-of-Experts (MoE) layer.
moe_top_k (int, optional, defaults to 1) – The number of top experts to use in the MoE layer.
moe_jitter_eps (float, optional) – The jitter epsilon value for the MoE layer.
moe_loss_weight (float, optional, defaults to 0.01) – The loss weight for the MoE auxiliary loss.
moe_normalize_expert_weights (float, optional, defaults to 1.0) – The normalization factor for the expert weights in the MoE layer.
uniform_expert_assignment (bool, optional, defaults to False) – Whether to use uniform expert assignment in the MoE layer.
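The MoE parameters interact inside the router. A per-token routing sketch, assuming the behaviour of the Hugging Face DBRX router: softmax over the router logits, keep the moe_top_k largest probabilities, then divide them by their p-norm with p = moe_normalize_expert_weights (skipped when None). moe_jitter_eps is assumed to scale the router input by uniform noise in [1 − eps, 1 + eps] during training only. uniform_expert_assignment is not modeled, and both helpers are hypothetical.

```python
import math
import random
from typing import List, Optional, Tuple


def route_token(
    router_logits: List[float],
    moe_top_k: int = 1,
    moe_normalize_expert_weights: Optional[float] = 1.0,
) -> Tuple[List[int], List[float]]:
    """Pick the top-k experts for one token and return (expert ids, weights)."""
    m = max(router_logits)
    exps = [math.exp(x - m) for x in router_logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Highest-probability experts first.
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:moe_top_k]
    weights = [probs[i] for i in top]
    if moe_normalize_expert_weights is not None:
        # Divide by the p-norm of the selected weights (p = moe_normalize_expert_weights).
        p = moe_normalize_expert_weights
        norm = sum(abs(w) ** p for w in weights) ** (1.0 / p)
        weights = [w / norm for w in weights]
    return top, weights


def jitter(
    hidden: List[float], moe_jitter_eps: Optional[float], rng: random.Random
) -> List[float]:
    """Training-time multiplicative noise in [1 - eps, 1 + eps]; None disables it."""
    if moe_jitter_eps is None:
        return list(hidden)
    return [h * rng.uniform(1.0 - moe_jitter_eps, 1.0 + moe_jitter_eps) for h in hidden]


print(route_token([0.0, 2.0, 1.0, -1.0], moe_top_k=2))
```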
- classmethod from_pretrained(pretrained_model_name_or_path: str, **kwargs: Any) EasyDeLBaseConfig[source]#
Instantiate a [PretrainedConfig] (or a derived class) from a pretrained model configuration.
- Parameters
pretrained_model_name_or_path (str or os.PathLike) –
This can be either:
a string, the model id of a pretrained model configuration hosted inside a model repo on huggingface.co.
a path to a directory containing a configuration file saved using the [~PretrainedConfig.save_pretrained] method, e.g., ./my_model_directory/.
a path or url to a saved configuration JSON file, e.g., ./my_model_directory/configuration.json.
cache_dir (str or os.PathLike, optional) – Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
force_download (bool, optional, defaults to False) – Whether or not to force to (re-)download the configuration files and override the cached versions if they exist.
resume_download – Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
proxies (Dict[str, str], optional) – A dictionary of proxy servers to use by protocol or endpoint, e.g., {‘http’: ‘foo.bar:3128’, ‘http://hostname’: ‘foo.bar:4012’}. The proxies are used on each request.
token (str or bool, optional) – The token to use as HTTP bearer authorization for remote files. If True, or not specified, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
revision (str, optional, defaults to “main”) –
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
<Tip>
To test a pull request you made on the Hub, you can pass `revision=”refs/pr/<pr_number>”.
</Tip>
return_unused_kwargs (bool, optional, defaults to False) –
If False, then this function returns just the final configuration object.
If True, then this functions returns a tp.Tuple(config, unused_kwargs) where unused_kwargs is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the part of kwargs which has not been used to update config and is otherwise ignored.
subfolder (str, optional, defaults to “”) – In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here.
kwargs (Dict[str, tp.Any], optional) – The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not configuration attributes is controlled by the return_unused_kwargs keyword parameter.
- Returns
The configuration object instantiated from this pretrained model.
- Return type
[PretrainedConfig]
Examples:
>>> # We can't instantiate directly the base class *PretrainedConfig* so let's show the examples on a >>> # derived class: BertConfig >>> config = BertConfig.from_pretrained( ... "google-bert/bert-base-uncased" >>> ) # Download configuration from huggingface.co and cache. >>> config = BertConfig.from_pretrained( ... "./test/saved_model/" >>> ) # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')* >>> config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json") >>> config = BertConfig.from_pretrained( ... "google-bert/bert-base-uncased", output_attentions=True, foo=False >>> ) >>> assert config.output_attentions == True >>> config, unused_kwargs = BertConfig.from_pretrained( ... "google-bert/bert-base-uncased", ... output_attentions=True, ... foo=False, ... return_unused_kwargs=True, >>> ) >>> assert config.output_attentions == True >>> assert unused_kwargs == {"foo": False}
- class easydel.__init__.DbrxForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.DbrxForSequenceClassification(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.DbrxModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Base DBRX Model outputting raw hidden-states.
This model is a Transformer-based model with a mixture of experts (MoE) architecture, implementing the DBRX architecture.
The model uses specialized attention modules and a router-based MoE FFN layer.
- property frequencies#
Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.
Uses self.config.get_basic_frequencies() and caches the result.
- Returns
The frequency components, potentially cached.
- Return type
jnp.ndarray
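The frequencies property is described only as RoPE frequency components produced by get_basic_frequencies() and cached. As a hedged illustration, the sketch computes the standard RoPE angles pos · θ^(−2i/d) for every position and channel pair and caches them with functools.lru_cache; the exact layout EasyDeL returns (angles, cos/sin tables, or something else) is not stated here, so this is not its implementation. θ plays the role of rope_theta (10000.0 by default in DbrxAttentionConfig).

```python
import math
from functools import lru_cache
from typing import Tuple


@lru_cache(maxsize=None)
def rope_angles(
    head_dim: int, max_positions: int, theta: float = 10000.0
) -> Tuple[Tuple[float, ...], ...]:
    """Rotation angle for every (position, frequency) pair, computed once and cached."""
    if head_dim % 2 != 0:
        raise ValueError("head_dim must be even")
    # One inverse frequency per pair of channels: theta ** (-2i / head_dim).
    inv_freq = [theta ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]
    return tuple(tuple(pos * f for f in inv_freq) for pos in range(max_positions))


def rotate_pair(x0: float, x1: float, angle: float) -> Tuple[float, float]:
    """Apply the 2-D rotation RoPE uses on one channel pair."""
    c, s = math.cos(angle), math.sin(angle)
    return x0 * c - x1 * s, x0 * s + x1 * c


angles = rope_angles(8, 4)
print(angles[1])
```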
- class easydel.__init__.DeepseekV2Config(vocab_size=102400, hidden_size=4096, intermediate_size=11008, moe_intermediate_size=1407, num_hidden_layers=30, num_attention_heads=32, num_key_value_heads=32, n_shared_experts=None, n_routed_experts=None, ep_size=1, routed_scaling_factor=1.0, kv_lora_rank=512, q_lora_rank=1536, qk_rope_head_dim=64, v_head_dim=128, qk_nope_head_dim=128, topk_method='gready', n_group=None, topk_group=None, num_experts_per_tok=None, moe_layer_freq=1, first_k_dense_replace=0, norm_topk_prob=False, scoring_func='softmax', aux_loss_alpha=0.001, seq_aux=True, hidden_act='silu', max_position_embeddings=2048, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, pad_token_id=None, bos_token_id=100000, eos_token_id=100001, pretraining_tp=1, tie_word_embeddings=False, rope_theta=10000.0, attention_bias=False, attention_dropout=0.0, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, rope_scaling: Dict[str, Union[str, float]] = None, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 102400) – Vocabulary size of the DeepseekV2 model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.
hidden_size (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.
intermediate_size (int, optional, defaults to 11008) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
moe_intermediate_size (int, optional, defaults to 1407) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the MoE layer.
num_hidden_layers (int, optional, defaults to 30) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (int, optional, defaults to 32) – Number of key and value heads for each attention layer in the Transformer encoder.
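n_shared_experts (int, optional) – Number of shared experts.
n_routed_experts (int, optional) – Number of routed experts.
ep_size (int, optional, defaults to 1) – Expert parallel size.
routed_scaling_factor (float, optional, defaults to 1.0) – Scaling factor applied to the routed experts' weights.
kv_lora_rank (int, optional, defaults to 512) – Rank of the low-rank (LoRA-style) matrices for the key and value projections.
q_lora_rank (int, optional, defaults to 1536) – Rank of the low-rank (LoRA-style) matrices for the query projection.
qk_rope_head_dim (int, optional, defaults to 64) – Per-head dimension of the query/key part that uses rotary position embeddings.
v_head_dim (int, optional, defaults to 128) – Per-head dimension of the values.
qk_nope_head_dim (int, optional, defaults to 128) – Per-head dimension of the query/key part that does not use rotary position embeddings.
topk_method (str, optional, defaults to “gready”) – Top-k method used by the routing gate.
n_group (int, optional) – Number of groups the routed experts are divided into.
topk_group (int, optional) – Number of groups selected for each token; a token's experts are chosen only from these groups.
num_experts_per_tok (int, optional) – Number of routed experts selected per token; None means a dense model.
moe_layer_freq (int, optional, defaults to 1) – Frequency of MoE layers: one expert layer for every moe_layer_freq - 1 dense layers.
first_k_dense_replace (int, optional, defaults to 0) – Number of leading decoder layers that use a dense MLP instead of an MoE layer.
norm_topk_prob (bool, optional, defaults to False) – Whether to normalize the weights of the selected routed experts.
scoring_func (str, optional, defaults to “softmax”) – Function used to compute expert routing scores.
aux_loss_alpha (float, optional, defaults to 0.001) – Weight coefficient of the auxiliary load-balancing loss.
seq_aux (bool, optional, defaults to True) – Whether to compute the auxiliary loss for each individual sequence.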
hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.
max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
pad_token_id (int, optional) – The index of the padding token in the vocabulary.
bos_token_id (int, optional, defaults to 100000) – The index of the beginning of sequence token in the vocabulary.
eos_token_id (int, optional, defaults to 100001) – The index of the end of sequence token in the vocabulary.
pretraining_tp (int, optional, defaults to 1) – Tensor parallelism rank used during pretraining.
tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.
rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.
attention_bias (bool, optional, defaults to False) – Whether to use attention bias.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.
use_scan_mlp (bool, optional, defaults to False) – Whether to use scan for MLP.
scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size for scan MLP.
bits (int, optional) – The number of bits to quantize the model to.
rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The rope scaling configuration.
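How moe_layer_freq, first_k_dense_replace and n_routed_experts combine is easiest to see by listing which layers become MoE layers. The sketch assumes the rule from DeepSeek's reference Hugging Face modeling code: a layer is MoE when n_routed_experts is set, its index is at least first_k_dense_replace, and the index is divisible by moe_layer_freq. It also shows that each query/key head is qk_nope_head_dim + qk_rope_head_dim wide. Both helpers are hypothetical.

```python
from typing import Dict, List, Optional


def moe_layer_indices(
    num_hidden_layers: int,
    n_routed_experts: Optional[int],
    first_k_dense_replace: int = 0,
    moe_layer_freq: int = 1,
) -> List[int]:
    """Indices of decoder layers that use a routed MoE block instead of a dense MLP."""
    if n_routed_experts is None:
        return []  # no routed experts configured: every layer is dense
    return [
        i
        for i in range(num_hidden_layers)
        if i >= first_k_dense_replace and i % moe_layer_freq == 0
    ]


def mla_head_dims(
    qk_nope_head_dim: int = 128, qk_rope_head_dim: int = 64, v_head_dim: int = 128
) -> Dict[str, int]:
    """Per-head sizes in DeepSeek-V2-style multi-head latent attention."""
    return {
        # Each query/key head concatenates a non-rotary part and a rotary part.
        "q_head_dim": qk_nope_head_dim + qk_rope_head_dim,
        "v_head_dim": v_head_dim,
    }


print(moe_layer_indices(6, n_routed_experts=8, moe_layer_freq=2))
print(mla_head_dims())
```

With the defaults documented above (n_routed_experts=None), every layer is dense; the DeepseekV3Config defaults (61 layers, first_k_dense_replace=3, n_routed_experts=256) would make layers 3 through 60 MoE layers under this rule.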
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- static get_weight_decay_exclusions()[source]#
Returns a tuple of parameter names for which weight decay should be excluded.
- Returns
An empty tuple, indicating no weight decay exclusions.
- Return type
tuple
- property granted_freq_max_position_embedding: int#
Returns the maximum position embedding size for frequency-based position embeddings.
- Returns
The maximum position embedding size, falling back to max_position_embeddings if not explicitly set.
- Return type
int
- property granted_mask_max_position_embedding: int#
Returns the maximum position embedding size for mask-based position embeddings.
- Returns
The maximum position embedding size, falling back to max_position_embeddings if not explicitly set.
- Return type
int
- model_type: str = 'deepseek_v2'#
- class easydel.__init__.DeepseekV2ForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
DeepseekV2 model with a language modeling head for causal language modeling tasks.
This model extends the base DeepseekV2Model by adding a linear language modeling head on top of the transformer model. It is intended for generative tasks such as text generation.
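Conceptually, the language modeling head is a linear projection from the hidden size to the vocabulary size. A minimal NumPy sketch, assuming a separate output kernel when weights are untied and reuse of the transposed input embedding table when tie_word_embeddings=True; the function lm_head_logits and the shapes are illustrative, not EasyDeL's implementation.

```python
import numpy as np


def lm_head_logits(hidden_states, embedding, lm_head_kernel=None, tie_word_embeddings=False):
    """Project (batch, seq, hidden) hidden states to (batch, seq, vocab) logits.

    Assumption (illustrative layout): embedding is (vocab, hidden), lm_head_kernel is (hidden, vocab).
    """
    if tie_word_embeddings:
        kernel = embedding.T  # reuse the input embeddings as the output projection
    else:
        if lm_head_kernel is None:
            raise ValueError("an untied head needs its own lm_head_kernel")
        kernel = lm_head_kernel
    return hidden_states @ kernel


rng = np.random.default_rng(0)
vocab, hidden = 16, 8
embedding = rng.normal(size=(vocab, hidden))
hidden_states = rng.normal(size=(2, 5, hidden))

logits = lm_head_logits(hidden_states, embedding, tie_word_embeddings=True)
next_token = logits[:, -1].argmax(axis=-1)  # greedy choice at the last position
print(logits.shape, next_token.shape)  # (2, 5, 16) (2,)
```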
- class easydel.__init__.DeepseekV2Model(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- property frequencies#
Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.
Uses self.config.get_basic_frequencies() and caches the result.
- Returns
The frequency components, potentially cached.
- Return type
jnp.ndarray
- class easydel.__init__.DeepseekV3Config(vocab_size=129280, hidden_size=7168, intermediate_size=18432, moe_intermediate_size=2048, num_hidden_layers=61, num_nextn_predict_layers=1, num_attention_heads=128, num_key_value_heads=128, n_shared_experts=1, n_routed_experts=256, ep_size=1, routed_scaling_factor=2.5, kv_lora_rank=512, q_lora_rank=1536, qk_rope_head_dim=64, v_head_dim=128, qk_nope_head_dim=128, topk_method='noaux_tc', n_group=8, topk_group=4, num_experts_per_tok=8, moe_layer_freq=1, first_k_dense_replace=3, norm_topk_prob=True, scoring_func='sigmoid', aux_loss_alpha=0.001, seq_aux=True, hidden_act='silu', max_position_embeddings=4096, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, pad_token_id=None, bos_token_id=0, eos_token_id=1, pretraining_tp=1, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, attention_bias=False, attention_dropout=0.0, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
This is the configuration class to store the configuration of a [DeepseekV3Model]. It is used to instantiate a DeepSeek-V3 model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration similar to that of DeepSeek-V3. Configuration objects inherit from [EasyDeLBaseConfig], itself a subclass of [PretrainedConfig], and can be used to control the model outputs. Read the documentation from [PretrainedConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 129280) – Vocabulary size of the DeepSeek-V3 model. Defines the number of different tokens that can be represented by the inputs_ids passed when calling [DeepseekV3Model].
hidden_size (int, optional, defaults to 7168) – Dimension of the hidden representations.
intermediate_size (int, optional, defaults to 18432) – Dimension of the MLP representations.
moe_intermediate_size (int, optional, defaults to 2048) – Dimension of the MoE representations.
num_hidden_layers (int, optional, defaults to 61) – Number of hidden layers in the Transformer decoder.
num_nextn_predict_layers (int, optional, defaults to 1) – Number of next-n prediction layers in the DeepSeek-V3 model.
num_attention_heads (int, optional, defaults to 128) – Number of attention heads for each attention layer in the Transformer decoder.
n_shared_experts (int, optional, defaults to 1) – Number of shared experts; None means a dense model.
n_routed_experts (int, optional, defaults to 256) – Number of routed experts; None means a dense model.
routed_scaling_factor (float, optional, defaults to 2.5) – Scaling factor for routed experts.
topk_method (str, optional, defaults to “noaux_tc”) – Top-k method used in the routed gate.
n_group (int, optional, defaults to 8) – Number of groups for routed experts.
topk_group (int, optional, defaults to 4) – Number of groups selected for each token; each token's experts are chosen only from within these topk_group groups.
num_experts_per_tok (int, optional, defaults to 8) – Number of selected experts; None means a dense model.
moe_layer_freq (int, optional, defaults to 1) – The frequency of the MoE layer: one expert layer for every moe_layer_freq - 1 dense layers.
first_k_dense_replace (int, optional, defaults to 3) – Number of dense layers at the start of the network before the MoE layers begin (embed -> dense -> … -> dense -> moe -> moe -> … -> lm_head, where the leading run of dense layers has length k).
norm_topk_prob (bool, optional, defaults to True) – Whether to normalize the weights of the routed experts.
scoring_func (str, optional, defaults to “sigmoid”) – Method of computing expert weights.
aux_loss_alpha (float, optional, defaults to 0.001) – Auxiliary loss weight coefficient.
seq_aux (bool, optional, defaults to True) – Whether to compute the auxiliary loss for each individual sample.
num_key_value_heads (int, optional) – The number of key/value heads used to implement Grouped Query Attention. If num_key_value_heads=num_attention_heads, the model uses Multi Head Attention (MHA); if num_key_value_heads=1, the model uses Multi Query Attention (MQA); otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by mean-pooling all the original heads within that group. For more details, check out [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it defaults to num_attention_heads.
hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) in the decoder.
max_position_embeddings (int, optional, defaults to 4096) – The maximum sequence length that this model might ever be used with.
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (float, optional, defaults to 1e-06) – The epsilon used by the rms normalization layers.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
pad_token_id (int, optional) – Padding token id.
bos_token_id (int, optional, defaults to 0) – Beginning of stream token id.
eos_token_id (int, optional, defaults to 1) – End of stream token id.
pretraining_tp (int, optional, defaults to 1) – Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is necessary to ensure exact reproducibility of the pretraining results. Please refer to [this issue](pytorch/pytorch#76232).
tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the input and output word embeddings.
rope_theta (float, optional, defaults to 10000.0) – The base period of the RoPE embeddings.
rope_scaling (Dict, optional) – Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is {“type”: strategy name, “factor”: scaling factor}. When using this flag, don’t update max_position_embeddings to the expected new maximum.
attention_bias (bool, optional, defaults to False) – Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
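The num_key_value_heads entry describes converting a multi-head checkpoint to GQA by mean-pooling the original key/value heads within each group. A NumPy sketch of that conversion for one projection kernel, assuming a (hidden, num_heads * head_dim) layout with heads stored contiguously along the output axis; the layout and the helper name mha_to_gqa_kernel are assumptions, not EasyDeL APIs.

```python
import numpy as np


def mha_to_gqa_kernel(kernel, num_heads, num_kv_heads):
    """Mean-pool K or V heads: (hidden, num_heads * head_dim) -> (hidden, num_kv_heads * head_dim).

    Assumption: heads are laid out contiguously along the output axis, and query head j
    shares KV head j // (num_heads // num_kv_heads).
    """
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    hidden, total = kernel.shape
    head_dim = total // num_heads
    group = num_heads // num_kv_heads
    # Split the output axis into (kv_head, heads_in_group, head_dim) and average each group.
    grouped = kernel.reshape(hidden, num_kv_heads, group, head_dim)
    return grouped.mean(axis=2).reshape(hidden, num_kv_heads * head_dim)


# hidden=4, four heads of head_dim=2 -> two KV heads (GQA) or one KV head (MQA).
k_kernel = np.arange(4 * 8, dtype=np.float64).reshape(4, 8)
gqa = mha_to_gqa_kernel(k_kernel, num_heads=4, num_kv_heads=2)
mqa = mha_to_gqa_kernel(k_kernel, num_heads=4, num_kv_heads=1)
print(gqa.shape, mqa.shape)  # (4, 4) (4, 2)
```

With num_kv_heads equal to num_heads, the function returns the kernel unchanged (MHA).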
Examples:
>>> from easydel import DeepseekV3Config
>>> # Initializing a DeepSeek-V3 style configuration
>>> configuration = DeepseekV3Config()
>>> configuration.model_type
'deepseek_v3'
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for model parameters.
These rules define how parameters should be sharded across devices when using model parallelism.
- Parameters
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns
A tuple of partition rules for different parameter patterns.
- Return type
Tuple
- keys_to_ignore_at_inference = ['past_key_values']#
- model_type: str = 'deepseek_v3'#
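The first_k_dense_replace and moe_layer_freq parameters together decide which decoder layers use MoE blocks. A sketch, assuming the rule used by the public DeepSeek reference modeling code (a layer is MoE when routed experts exist, its index is at least first_k_dense_replace, and the index is divisible by moe_layer_freq); EasyDeL's exact rule may differ, and layer_types is a hypothetical helper.

```python
from typing import List, Optional


def layer_types(num_hidden_layers: int, first_k_dense_replace: int, moe_layer_freq: int,
                n_routed_experts: Optional[int]) -> List[str]:
    """Label each decoder layer "dense" or "moe" (assumed DeepSeek reference rule)."""
    types = []
    for idx in range(num_hidden_layers):
        is_moe = (
            n_routed_experts is not None      # None means a fully dense model
            and idx >= first_k_dense_replace  # the first k layers stay dense
            and idx % moe_layer_freq == 0     # afterwards, every moe_layer_freq-th layer is MoE
        )
        types.append("moe" if is_moe else "dense")
    return types


# DeepSeek-V3 defaults from the signature above: 61 layers, first 3 dense, MoE everywhere after.
types = layer_types(61, first_k_dense_replace=3, moe_layer_freq=1, n_routed_experts=256)
print(types[:5], types.count("moe"))  # ['dense', 'dense', 'dense', 'moe', 'moe'] 58
```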
- class easydel.__init__.DeepseekV3ForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
DeepseekV3 model with a language modeling head for causal language modeling tasks.
This model extends the base DeepseekV3Model by adding a linear language modeling head on top of the transformer model. It incorporates a Mixture-of-Experts (MoE) architecture and is intended for generative tasks such as text generation.
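The routing parameters of DeepseekV3Config (scoring_func, n_group, topk_group, num_experts_per_tok, norm_topk_prob, routed_scaling_factor) describe a group-limited top-k gate. A simplified NumPy sketch for a single token, assuming sigmoid scoring and groups ranked by the sum of their two best expert scores, as in the public DeepSeek-V3 reference code. It omits the learned score-correction bias that noaux_tc adds during expert selection, and group_limited_topk is a hypothetical name, not EasyDeL's gate.

```python
import numpy as np


def group_limited_topk(logits, n_group, topk_group, top_k,
                       norm_topk_prob=True, routed_scaling_factor=2.5):
    """Pick top_k experts for one token, restricted to the best topk_group groups.

    Simplified DeepSeek-V3 style gate. Assumption: the learned score-correction bias used by
    topk_method="noaux_tc" for expert selection is omitted here.
    """
    scores = 1.0 / (1.0 + np.exp(-logits))                 # scoring_func="sigmoid"
    groups = scores.reshape(n_group, -1)                   # (n_group, experts_per_group)
    # Rank groups by the sum of their two highest expert scores and keep topk_group of them.
    group_scores = np.sort(groups, axis=-1)[:, -2:].sum(axis=-1)
    kept_groups = np.argsort(group_scores)[::-1][:topk_group]
    allowed = np.zeros_like(groups, dtype=bool)
    allowed[kept_groups] = True
    # Experts outside the kept groups can never be selected.
    candidates = np.where(allowed.reshape(-1), scores, -np.inf)
    expert_idx = np.argsort(candidates)[::-1][:top_k]
    weights = scores[expert_idx]
    if norm_topk_prob:
        weights = weights / weights.sum()
    return expert_idx, weights * routed_scaling_factor


# 8 experts in 4 groups of 2; expert 1 has the second-highest score but its group is dropped.
logits = np.array([0.1, 3.0, 2.0, 2.5, -1.0, 0.0, 1.0, 4.0])
expert_idx, weights = group_limited_topk(logits, n_group=4, topk_group=2, top_k=2)
print(expert_idx)  # [7 3]
```

With norm_topk_prob=True, the returned weights sum to routed_scaling_factor.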
- class easydel.__init__.DeepseekV3Model(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- property frequencies#
Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.
Uses self.config.get_basic_frequencies() and caches the result.
- Returns
The frequency components, potentially cached.
- Return type
jnp.ndarray
- class easydel.__init__.EasyDeLBackends(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#
Bases:
str, Enum
Enumeration of JAX backend types supported by EasyDeL.
Specifies the target hardware device type for JAX computations.
- CPU#
Use the CPU backend.
- GPU#
Use the GPU backend.
- TPU#
Use the TPU backend.
- CPU = 'cpu'#
- GPU = 'gpu'#
- TPU = 'tpu'#
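Because the class mixes in str, its members compare equal to their plain string values. A minimal stand-in mirroring the members listed above shows this behavior; the class name Backends is hypothetical, and EasyDeL's own enum is EasyDeLBackends.

```python
from enum import Enum


class Backends(str, Enum):
    """Hypothetical mirror of EasyDeLBackends, used only to show str-Enum behavior."""

    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"


# Lookup by value returns the singleton member, and members compare equal to plain strings,
# so a platform string such as the one returned by jax.default_backend() can be matched directly.
member = Backends("gpu")
print(member is Backends.GPU, Backends.TPU == "tpu")  # True True
```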
- class easydel.__init__.EasyDeLBaseConfig(axis_dims: ~typing.Sequence[int] = (1, -1, 1, 1), dcn_axis_dims: ~typing.Optional[~typing.Sequence[int]] = None, axis_names: ~typing.Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), attn_mechanism: ~typing.Literal['vanilla', 'flash_attn2', 'splash', 'ring', 'cudnn', 'blockwise', 'sdpa'] = 'vanilla', blocksize_k: int = 128, blocksize_q: int = 128, blocksize_b: int = 1, partition_axis: ~eformer.escale.partition.constraints.PartitionAxis = PartitionAxis( data_parallel_axis : dp fully_sharded_data_parallel_axis : fsdp tensor_parallel_axis : tp sequence_parallel_axis : sp expert_parallel_axis : ep batch_axis : ('fsdp', 'dp') sequence_axis : sp query_sequence_axis : sp head_axis : tp key_sequence_axis : sp hidden_state_axis : tp mlp_intermediate_axis : tp vocab_axis : tp expert_axis : ep expert_gate_axis : None attention_dim_axis : None bias_head_sequence_axis : None bias_key_sequence_axis : None generation_batch_axis : None generation_query_sequence_axis : None generation_head_axis : tp generation_key_sequence_axis : sp generation_attention_dim_axis : None ), shard_attention_computation: bool = True, use_sharded_kv_caching: bool = False, use_sharding_constraint: bool = False, backend: ~typing.Optional[~easydel.__init__.infra.etils.EasyDeLBackends] = None, platform: ~typing.Optional[~easydel.__init__.infra.etils.EasyDeLPlatforms] = None, easy_method: ~typing.Literal['train', 'serve', 'convert'] = 'train', bits: ~typing.Optional[int] = None, scan_ring_attention: bool = True, scan_attention_layers: bool = False, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, sequence_axis_name: str = 'sp', gradient_checkpointing: ~easydel.__init__.infra.etils.EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, precompute_masks: bool = True, kv_cache_quantization_method: ~easydel.__init__.infra.etils.EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.NONE, kv_cache_quantization_blocksize: int = 64, quantization_method: 
~easydel.__init__.infra.etils.EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.NONE, quantization_pattern: str = '.*', quantization_blocksize: int = 64, kv_cache_sharding_sequence_axis_name: ~typing.Union[str, ~typing.Tuple[str, ...]] = 'sp', flash_attention_backward_pass_impl: ~typing.Literal['triton', 'xla'] = 'triton', attn_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, attn_softmax_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, fcm_max_ratio: float = 0.0, fcm_min_ratio: float = 0.0, hardware_abstraction: bool = False, pallas_m_block_size: int = 128, pallas_k_block_size: int = 128, pallas_n_block_size: int = 128, **kwargs)[source]#
Bases:
PretrainedConfig
Initialize the configuration for EasyDeL.
- Parameters
axis_dims (tp.Sequence[int]) – Dimensions of the axes. Default is (1, -1, 1, 1).
axis_names (tp.Sequence[str]) – Names of the axes. Default is (“dp”, “fsdp”, “tp”, “sp”).
attn_mechanism (AVAILABLE_ATTENTION_MECHANISMS) – Attention mechanism to use. Default is DEFAULT_ATTENTION_MECHANISM.
blocksize_k (int) – Block size for key. Default is 128.
blocksize_q (int) – Block size for query. Default is 128.
blocksize_b (int) – Block size for batch. Default is 1.
partition_axis (PartitionAxis) – Partition axis configuration. Default is PartitionAxis().
shard_attention_computation (bool) – Whether to shard attention computation. Default is True.
use_sharded_kv_caching (bool) – Whether to use sharded key-value caching. Default is False.
use_sharding_constraint (bool) – Whether to use sharding constraints. Default is False.
backend (tp.Optional[EasyDeLBackends]) – Backend to use. Default is None.
platform (tp.Optional[EasyDeLPlatforms]) – Platform to use. Default is None.
easy_method (tp.Literal[“train”, “serve”, “convert”]) – Usage mode the configuration is prepared for. Default is EasyMethod.TRAIN.
bits (tp.Optional[int]) – Number of bits for quantization. Default is None.
scan_ring_attention (bool) – Whether to use scan for ring attention. Default is True.
scan_attention_layers (bool) – Whether to use scan for attention layers. Default is False.
use_scan_mlp (bool) – Whether to use scan MLP. Default is False.
scan_mlp_chunk_size (int) – Chunk size for scan MLP. Default is 1024.
sequence_axis_name (str) – Name of the sequence axis used for attention. Default is “sp”.
gradient_checkpointing (EasyDeLGradientCheckPointers) – Gradient checkpointing method. Default is EasyDeLGradientCheckPointers.NONE.
kv_cache_quantization_method (EasyDeLQuantizationMethods) – Key-value cache quantization method. Default is EasyDeLQuantizationMethods.NONE.
kv_cache_quantization_blocksize (int) – Block size for key-value cache quantization. Default is 64.
quantization_method (EasyDeLQuantizationMethods) – Quantization method. Default is EasyDeLQuantizationMethods.NONE.
quantization_pattern (str) – Regular-expression pattern selecting the layers to quantize. Default is “.*”.
quantization_blocksize (int) – Block size for quantization. Default is 64.
kv_cache_sharding_sequence_axis_name (tp.Union[str, tp.Tuple[str, …]]) – Name of the key-value cache sharding sequence axis. Default is “sp”.
flash_attention_backward_pass_impl (tp.Literal[“triton”, “xla”]) – Implementation for the flash attention backward pass. Default is “triton”.
attn_dtype (jnp.dtype) – Data type for attention. Default is jnp.float32.
attn_softmax_dtype (jnp.dtype) – Data type for softmax ops in attention. Default is jnp.float32.
fcm_max_ratio (float) – Maximum masking ratio for Forgetful Causal Masking (FCM). Default is 0.0.
fcm_min_ratio (float) – Minimum masking ratio for Forgetful Causal Masking (FCM). Default is 0.0.
hardware_abstraction (bool) – Whether to use hardware abstraction. Default is DEFAULT_HARDWARE_ABSTRACTION.
pallas_m_block_size (int) – Block size for Pallas M. Default is DEFAULT_PALLAS_M_BLOCK_SIZE.
pallas_k_block_size (int) – Block size for Pallas K. Default is DEFAULT_PALLAS_K_BLOCK_SIZE.
pallas_n_block_size (int) – Block size for Pallas N. Default is DEFAULT_PALLAS_N_BLOCK_SIZE.
**kwargs – Additional keyword arguments.
- Raises
Warning – If kv_cache_quantization_method is not NONE and use_sharded_kv_caching is True.
- add_basic_configurations(axis_dims: Sequence[int] = Ellipsis, dcn_axis_dims: Optional[Sequence[int]] = Ellipsis, axis_names: Sequence[str] = Ellipsis, attn_mechanism: Literal['vanilla', 'flash_attn2', 'splash', 'ring', 'cudnn', 'blockwise', 'sdpa'] = Ellipsis, blocksize_k: int = Ellipsis, blocksize_q: int = Ellipsis, blocksize_b: int = Ellipsis, partition_axis: PartitionAxis = Ellipsis, shard_attention_computation: bool = Ellipsis, use_sharded_kv_caching: bool = Ellipsis, backend: Optional[EasyDeLBackends] = Ellipsis, platform: Optional[EasyDeLPlatforms] = Ellipsis, easy_method: Literal['train', 'serve', 'convert'] = Ellipsis, bits: Optional[int] = Ellipsis, scan_ring_attention: bool = Ellipsis, scan_attention_layers: bool = Ellipsis, use_sharding_constraint: bool = Ellipsis, use_scan_mlp: bool = Ellipsis, scan_mlp_chunk_size: int = Ellipsis, sequence_axis_name: str = Ellipsis, gradient_checkpointing: EasyDeLGradientCheckPointers = Ellipsis, precompute_masks: bool = Ellipsis, kv_cache_quantization_method: EasyDeLQuantizationMethods = Ellipsis, kv_cache_quantization_blocksize: int = Ellipsis, quantization_method: EasyDeLQuantizationMethods = Ellipsis, quantization_blocksize: int = Ellipsis, quantization_pattern: str = Ellipsis, kv_cache_sharding_sequence_axis_name: Union[str, Tuple[str, ...]] = Ellipsis, flash_attention_backward_pass_impl: Literal['triton', 'xla'] = Ellipsis, attn_dtype: dtype = Ellipsis, attn_softmax_dtype: dtype = Ellipsis, hardware_abstraction: bool = Ellipsis, pallas_m_block_size: int = Ellipsis, pallas_k_block_size: int = Ellipsis, pallas_n_block_size: int = Ellipsis, **kwargs)[source]#
Sets the basic EasyDeL attributes (mesh axes, attention, sharding, quantization, and kernel settings) on an existing configuration instance.
- Parameters
axis_dims (tp.Sequence[int], optional) – Specify the number of dimensions for each axis. Defaults to (1, -1, 1, 1).
axis_names (tp.Sequence[str], optional) – Set the names of the axes. Defaults to (“dp”, “fsdp”, “tp”, “sp”).
attn_mechanism (AVAILABLE_ATTENTION_MECHANISMS, optional) – attention mechanism to use. Defaults to DEFAULT_ATTENTION_MECHANISM.
blocksize_k (int, optional) – block size of key_states. Defaults to 128.
blocksize_q (int, optional) – block size of query_states. Defaults to 128.
blocksize_b (int, optional) – block size of bias. Defaults to 1.
partition_axis (PartitionAxis, optional) – PartitionAxis is the module EasyDeL uses to describe how arrays are partitioned. Defaults to PartitionAxis().
shard_attention_computation (bool, optional) – Whether to use shard_map for attention. Defaults to True.
use_sharded_kv_caching (bool, optional) – Whether to use shard_map and sharding for the key and value cache. Defaults to False.
backend (tp.Optional[EasyDeLBackends], optional) – Specify the backend to use. Defaults to None.
platform (tp.Optional[EasyDeLPlatforms], optional) – Specify the platform to use. Defaults to None.
easy_method (tp.Literal["train", "serve", "convert"], optional) – Usage mode EasyDeL prepares the model for (training, serving, or conversion). Defaults to EasyMethod.TRAIN.
bits (tp.Optional[int], optional) – Model bits for quantization. Defaults to None.
scan_ring_attention (bool, optional) – Whether to use scan for ring attention. Defaults to True.
scan_attention_layers (bool, optional) – Whether to use scan for attention layers. Defaults to False.
use_sharding_constraint (bool, optional) – whether to use sharding constraint for the arrays. Defaults to False.
use_scan_mlp (bool, optional) – Determine whether to use scan_mlp or not. Defaults to False.
scan_mlp_chunk_size (int, optional) – Size of chunks in scan MLP. Defaults to 1024.
sequence_axis_name (str, optional) – Name of the sequence axis used for attention. Defaults to “sp”.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing method for the created or loaded module (typically applied to the MLP and attention layers). Defaults to EasyDeLGradientCheckPointers.NONE.
kv_cache_quantization_method (EasyDeLQuantizationMethods, optional) – Key and value cache quantization type. Defaults to EasyDeLQuantizationMethods.NONE.
kv_cache_quantization_blocksize (int, optional) – Block size for KV-cache quantization. Defaults to 64.
quantization_method (EasyDeLQuantizationMethods, optional) – Quantization type for linear modules. Defaults to EasyDeLQuantizationMethods.NONE.
quantization_blocksize (int, optional) – Block size for linear-layer quantization. Defaults to 64.
quantization_pattern (str) – Regular-expression pattern selecting the layers to quantize. Defaults to “.*”.
kv_cache_sharding_sequence_axis_name (tp.Union[str, tp.Tuple[str, ...]], optional) – axis name to target for sharding sequences. Defaults to “sp”.
flash_attention_backward_pass_impl (tp.Literal["triton", "xla"], optional) – Specify the backward pass kernel for flash attention. Defaults to “triton”.
attn_dtype (jnp.dtype, optional) – Data type for attention computations. Defaults to jnp.float32.
attn_softmax_dtype (jnp.dtype, optional) – Data type for softmax in attention op computations. Defaults to jnp.float32.
fcm_max_ratio (float, optional) – Maximum masking ratio for Forgetful Causal Masking (FCM). Defaults to 0.0.
fcm_min_ratio (float, optional) – Minimum masking ratio for Forgetful Causal Masking (FCM). Defaults to 0.0.
hardware_abstraction (bool, optional) – Whether to switch to custom Pallas kernels instead of the default JAX implementations. Defaults to DEFAULT_HARDWARE_ABSTRACTION.
pallas_m_block_size (int, optional) – Block size of the m dimension in the Pallas matmul kernel A(mk)@B(kn)=C(mn). Defaults to DEFAULT_PALLAS_M_BLOCK_SIZE.
pallas_k_block_size (int, optional) – Block size of the k dimension in the Pallas matmul kernel A(mk)@B(kn)=C(mn). Defaults to DEFAULT_PALLAS_K_BLOCK_SIZE.
pallas_n_block_size (int, optional) – Block size of the n dimension in the Pallas matmul kernel A(mk)@B(kn)=C(mn). Defaults to DEFAULT_PALLAS_N_BLOCK_SIZE.
- static create_mesh(axis_dims: Sequence[int] = (1, -1, 1, 1), axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), dcn_axis_dims: Optional[Sequence[int]] = None, process_is_granule: bool = False, should_sort_granules_by_key: bool = True, allow_split_physical_axes: bool = True, backend: Optional[str] = None)[source]#
The create_mesh function creates a mesh object that can be used to shard arrays.
- Returns
A mesh object
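The default axis_dims=(1, -1, 1, 1) places every device on the fsdp axis. A NumPy sketch of how such a shape can be resolved, assuming -1 means "all remaining devices" as in numpy.reshape; the helper resolve_axis_dims is hypothetical, and integers stand in for devices. With real devices, an array built from jax.devices() and reshaped this way could be passed to jax.sharding.Mesh together with axis_names.

```python
import numpy as np


def resolve_axis_dims(axis_dims, num_devices):
    """Resolve a mesh shape; a single -1 entry absorbs the remaining devices (assumed convention)."""
    dims = list(axis_dims)
    if dims.count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    known = int(np.prod([d for d in dims if d != -1]))
    if -1 not in dims:
        if known != num_devices:
            raise ValueError(f"{axis_dims} does not use exactly {num_devices} devices")
        return tuple(dims)
    if num_devices % known != 0:
        raise ValueError(f"{num_devices} devices cannot be split into {axis_dims}")
    return tuple(num_devices // known if d == -1 else d for d in dims)


axis_names = ("dp", "fsdp", "tp", "sp")
devices = np.arange(8)  # integers stand in for 8 accelerator devices
shape = resolve_axis_dims((1, -1, 2, 1), devices.size)
device_grid = devices.reshape(shape)  # with a real device array: jax.sharding.Mesh(grid, axis_names)
print(dict(zip(axis_names, shape)))  # {'dp': 1, 'fsdp': 4, 'tp': 2, 'sp': 1}
```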
- classmethod from_pretrained(pretrained_model_name_or_path: Union[str, PathLike], cache_dir: Optional[Union[str, PathLike]] = None, force_download: bool = False, local_files_only: bool = False, token: Optional[Union[str, bool]] = None, revision: str = 'main', **kwargs) PretrainedConfig[source]#
Instantiate a [PretrainedConfig] (or a derived class) from a pretrained model configuration.
- Parameters
pretrained_model_name_or_path (str or os.PathLike) –
This can be either:
a string, the model id of a pretrained model configuration hosted inside a model repo on huggingface.co.
a path to a directory containing a configuration file saved using the [~PretrainedConfig.save_pretrained] method, e.g., ./my_model_directory/.
a path or url to a saved configuration JSON file, e.g., ./my_model_directory/configuration.json.
cache_dir (str or os.PathLike, optional) – Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
force_download (bool, optional, defaults to False) – Whether or not to force to (re-)download the configuration files and override the cached versions if they exist.
resume_download – Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
proxies (Dict[str, str], optional) – A dictionary of proxy servers to use by protocol or endpoint, e.g., {‘http’: ‘foo.bar:3128’, ‘http://hostname’: ‘foo.bar:4012’}. The proxies are used on each request.
token (str or bool, optional) – The token to use as HTTP bearer authorization for remote files. If True, or not specified, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
revision (str, optional, defaults to “main”) –
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
To test a pull request you made on the Hub, you can pass revision="refs/pr/<pr_number>".
return_unused_kwargs (bool, optional, defaults to False) –
If False, then this function returns just the final configuration object.
If True, then this function returns a tuple (config, unused_kwargs) where unused_kwargs is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the part of kwargs which has not been used to update config and is otherwise ignored.
subfolder (str, optional, defaults to “”) – In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here.
kwargs (Dict[str, tp.Any], optional) – The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not configuration attributes is controlled by the return_unused_kwargs keyword parameter.
- Returns
The configuration object instantiated from this pretrained model.
- Return type
[PretrainedConfig]
Examples:
>>> from transformers import BertConfig
>>> # We can't instantiate directly the base class *PretrainedConfig* so let's show the examples on a
>>> # derived class: BertConfig
>>> config = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased"
... )  # Download configuration from huggingface.co and cache.
>>> config = BertConfig.from_pretrained(
...     "./test/saved_model/"
... )  # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
>>> config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
>>> config = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased", output_attentions=True, foo=False
... )
>>> assert config.output_attentions == True
>>> config, unused_kwargs = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased",
...     output_attentions=True,
...     foo=False,
...     return_unused_kwargs=True,
... )
>>> assert config.output_attentions == True
>>> assert unused_kwargs == {"foo": False}
- get_axis_dims() Sequence[int][source]#
The get_axis_dims function returns a sequence of integers representing the dimensions of each axis.
- Parameters
self – The configuration instance.
- Returns
The dimensions of the axes
- get_axis_names() Sequence[str][source]#
The get_axis_names function returns a list of the names of the axes.
- Parameters
self – The configuration instance.
- Returns
A list of the names of all axes
- get_backend() str[source]#
The get_backend function returns the backend that is currently being used. If no backend has been set, it will return the default JAX backend.
- Parameters
self – The configuration instance.
- Returns
The backend platform
- get_basic_frequencies(head_size: Optional[int] = None, rotary_dim: Optional[int] = None, base: Optional[float] = None) Any[source]#
Get basic frequencies for rotary embeddings.
- Parameters
head_size – Size of attention heads (defaults to self.head_dim)
rotary_dim – Dimension for rotary embeddings (defaults to head_size)
base – Base value for frequency computation (defaults to self.rope_theta)
- Returns
ModuleCaches instance containing computed frequencies
- get_basic_inv_frequencies(head_size: Optional[int] = None, rotary_dim: Optional[int] = None, base: Optional[float] = None, partial_rotary_factor: float = 1.0) Any[source]#
Get basic inverse frequencies for rotary embeddings.
- Parameters
head_size – Size of attention heads (defaults to self.head_dim)
rotary_dim – Dimension for rotary embeddings (defaults to head_size)
base – Base value for frequency computation (defaults to self.rope_theta)
- Returns
ModuleCaches instance containing computed frequencies
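The standard RoPE inverse frequencies are base ** (-2i / rotary_dim) for i = 0, ..., rotary_dim / 2 - 1, and the rotation angle at position p is p * inv_freq. A NumPy sketch of both, assuming partial_rotary_factor shrinks the rotary dimension to int(head_size * partial_rotary_factor) (a common convention, not confirmed for this method). The real methods return a ModuleCaches wrapper rather than bare arrays, and rope_inv_frequencies and rope_angles are hypothetical names.

```python
import numpy as np


def rope_inv_frequencies(head_size, rotary_dim=None, base=10000.0, partial_rotary_factor=1.0):
    """Standard RoPE inverse frequencies, shape (rotary_dim // 2,)."""
    if rotary_dim is None:
        # Assumption: partial_rotary_factor scales the rotated fraction of each head.
        rotary_dim = int(head_size * partial_rotary_factor)
    exponents = np.arange(0, rotary_dim, 2, dtype=np.float64) / rotary_dim
    return 1.0 / (base ** exponents)


def rope_angles(max_positions, **kwargs):
    """Hypothetical angle table position * inv_freq, shape (max_positions, rotary_dim // 2)."""
    inv_freq = rope_inv_frequencies(**kwargs)
    return np.outer(np.arange(max_positions, dtype=np.float64), inv_freq)


inv_freq = rope_inv_frequencies(head_size=64)  # base plays the role of rope_theta
angles = rope_angles(4096, head_size=64, base=10000.0)
print(inv_freq.shape, angles.shape)  # (32,) (4096, 32)
```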
- get_basic_rope(dtype: Union[Array, ndarray, bool, number], head_size: int, rotary_dim: Optional[int] = None, is_neox_style: bool = True, base: Optional[float] = None)[source]#
Get basic rotary position embeddings.
- Parameters
dtype – Data type for the embeddings
head_size – Size of attention heads
rotary_dim – Dimension for rotary embeddings (defaults to head_size)
is_neox_style – Whether to use NeoX style embeddings
base – Base value for frequency computation (defaults to self.rope_theta)
- Returns
A callable that applies rotary position embeddings.
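The is_neox_style flag selects how the rotary dimension is split into pairs: NeoX style rotates the first half of each head against the second half, while the GPT-J style rotates interleaved even/odd elements. A NumPy sketch of both, assuming the standard formulation with rotary_dim equal to head_size; apply_rope is a hypothetical helper, not the callable returned by get_basic_rope.

```python
import numpy as np


def apply_rope(x, positions, base=10000.0, is_neox_style=True):
    """Rotate x of shape (seq, head_size) by position-dependent angles (assumes rotary_dim == head_size)."""
    head_size = x.shape[-1]
    inv_freq = 1.0 / (base ** (np.arange(0, head_size, 2) / head_size))
    angles = np.outer(positions, inv_freq)  # (seq, head_size // 2)
    cos, sin = np.cos(angles), np.sin(angles)
    if is_neox_style:
        # NeoX: pair element i with element i + head_size // 2.
        x1, x2 = x[..., : head_size // 2], x[..., head_size // 2 :]
        return np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    # GPT-J: pair element 2i with element 2i + 1.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x2 * cos + x1 * sin
    return out


rng = np.random.default_rng(0)
q = rng.normal(size=(6, 8))  # 6 positions, head_size 8
positions = np.arange(6)
q_neox = apply_rope(q, positions, is_neox_style=True)
q_gptj = apply_rope(q, positions, is_neox_style=False)
```

Both styles are pure rotations, so vector norms are preserved and the query/key dot product depends only on the relative offset between positions.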
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- property granted_freq_max_position_embedding: int#
- property granted_mask_max_position_embedding: int#
- property mesh#
The mesh property is a helper property that creates a Mesh object from the axis_dims and axis_names attributes of an object, which are assumed to be lists of integers and strings, respectively. The platform attribute is also used if it exists.
- Parameters
self – The configuration instance.
- Returns
A jax.sharding.Mesh object.
- read_basics_from_config(config: EasyDeLBaseConfig)[source]#
- class easydel.__init__.EasyDeLBaseModule(*args: Any, **kwargs: Any)[source]#
Bases:
Module, BaseModuleProtocol, EasyBridgeMixin, EasyGenerationMixin
Base class for EasyDeL modules, providing common functionalities for model initialization, parameter handling, and integration with the EasyDeL ecosystem.
- apply_lora_to_layers(lora_rank: int, lora_pattern: Optional[str] = None, verbose: bool = False, rngs: Optional[Rngs] = None) SELF[source]#
Applies Low-Rank Adaptation (LoRA) layers to the specified linear layers within the module.
Replaces targeted flax.linen.Dense layers with easydel.layers.lora.LoraLinear layers, initializing the LoRA matrices (A and B).
- Parameters
lora_rank (int) – The rank of the LoRA decomposition.
lora_pattern (tp.Optional[str], optional) – A regular expression to match the names of the Dense layers to apply LoRA to. If None, applies to common attention and MLP layers. Defaults to None.
verbose (bool, optional) – If True, prints information about which layers are being modified. Defaults to False.
rngs (tp.Optional[nn.Rngs], optional) – JAX random number generators for initializing LoRA matrices. If None, default RNGs might be used. Defaults to None.
- Returns
The module instance with LoRA layers applied.
- Return type
SELF
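LoRA adds a trainable low-rank update to a frozen kernel, y = x @ W + scale * (x @ A) @ B. A NumPy sketch of one adapted projection plus regex-based layer selection, assuming the usual initialization (random A, zero B, so the adapted layer initially matches the original) and full-string regex matching. EasyDeL's LoraLinear may initialize, scale, or match names differently, and the layer names, init_lora, and lora_forward are hypothetical.

```python
import re

import numpy as np


def init_lora(kernel, rank, rng):
    """Create LoRA factors for an (in_features, out_features) kernel."""
    in_features, out_features = kernel.shape
    a = rng.normal(scale=1.0 / np.sqrt(in_features), size=(in_features, rank))
    b = np.zeros((rank, out_features))  # assumption: zero init, so the update starts at 0
    return a, b


def lora_forward(x, kernel, a, b, scale=1.0):
    """Frozen projection plus the scaled low-rank update."""
    return x @ kernel + scale * (x @ a) @ b


# Hypothetical layer names; only those matching the pattern receive adapters.
layers = ["layers/0/self_attn/q_proj", "layers/0/self_attn/v_proj", "layers/0/mlp/up_proj"]
pattern = r".*(q_proj|v_proj).*"
targets = [name for name in layers if re.fullmatch(pattern, name)]

rng = np.random.default_rng(0)
kernel = rng.normal(size=(16, 16))
a, b = init_lora(kernel, rank=4, rng=rng)  # 128 trainable values instead of 256
x = rng.normal(size=(2, 16))
print(targets)  # ['layers/0/self_attn/q_proj', 'layers/0/self_attn/v_proj']
```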
- property causal_mask: Array#
Retrieves or computes the basic causal attention mask from the configuration.
Uses self.config.get_basic_causal_mask() and caches the result.
- Returns
The causal attention mask, potentially cached.
- Return type
jnp.ndarray
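A causal mask lets position i attend only to positions j <= i. A NumPy sketch, assuming a boolean lower-triangular mask of shape (1, 1, max_len, max_len) that broadcasts over batch and heads, built once and sliced to the current sequence length before being combined with a padding mask. The exact shape and dtype EasyDeL caches may differ, and basic_causal_mask is a hypothetical name.

```python
import numpy as np


def basic_causal_mask(max_len):
    """Boolean mask of shape (1, 1, max_len, max_len); True where attention is allowed (assumed layout)."""
    return np.tril(np.ones((max_len, max_len), dtype=bool))[None, None]


causal = basic_causal_mask(8)  # built once, e.g. for the maximum position count
seq_len = 4
attention_mask = np.array([[1, 1, 1, 0]], dtype=bool)  # (batch, seq): last token is padding
# Slice to the current length and drop attention to padded keys.
mask = causal[:, :, :seq_len, :seq_len] & attention_mask[:, None, None, :]
print(mask[0, 0].astype(int))
```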
- compute_loss(*, labels: Optional[Union[Array, ndarray, bool, number]] = None, loss_config: Optional[LossConfig] = None, loss_kwargs: Optional[Dict] = None, **batch) Tuple[Any, LossMetrics][source]#
Computes the loss for the model given a batch of inputs and labels.
This method performs a forward pass using the provided batch arguments, then calculates the loss using the determined loss_function. It handles potential label inference (e.g., using input_ids as labels for Causal LM) and default loss configurations.
- Parameters
labels (tp.Optional[chex.Array], optional) – The target labels. If None and the task is Causal LM, input_ids from the batch might be used. Defaults to None.
loss_config (tp.Optional[LossConfig], optional) – Specific configuration for the loss calculation. If None, defaults might be inferred (e.g., for sequence classification). Defaults to None.
loss_kwargs (tp.Optional[tp.Dict], optional) – Additional keyword arguments to pass directly to the loss function. Defaults to None.
**batch – Keyword arguments representing the input batch (e.g., input_ids, attention_mask).
- Returns
A tuple containing:
The model’s output (a pytree, typically including logits, hidden states, etc.).
A LossMetrics object containing the calculated loss and, potentially, other metrics.
- Return type
tp.Tuple[tp.Any, LossMetrics]
- Raises
AssertionError – If labels are required for the loss function but are not provided or inferred.
AssertionError – If sequence classification loss is used without num_labels in the config.
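When the task is Causal LM and the labels are taken from input_ids, the usual approach is a shifted cross-entropy: the logits at position t are scored against the token at position t + 1. The following generic sketch shows only that core computation (no padding mask, no auxiliary terms such as MoE router losses); it is an illustration, not EasyDeL's ForCausalLMLoss.

```python
import jax
import jax.numpy as jnp


def causal_lm_loss(logits: jnp.ndarray, input_ids: jnp.ndarray) -> jnp.ndarray:
    """Mean next-token cross-entropy, using input_ids as labels (generic sketch)."""
    shift_logits = logits[:, :-1, :]  # predictions for positions 0..T-2
    shift_labels = input_ids[:, 1:]   # targets are the following tokens
    log_probs = jax.nn.log_softmax(shift_logits, axis=-1)
    token_log_probs = jnp.take_along_axis(
        log_probs, shift_labels[..., None], axis=-1
    )[..., 0]
    return -token_log_probs.mean()


vocab_size = 5
input_ids = jnp.array([[1, 2, 3, 4]])
uniform_logits = jnp.zeros((1, 4, vocab_size))
loss = causal_lm_loss(uniform_logits, input_ids)
print(loss)  # approximately log(5), since every prediction is uniform
```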
- float(change_runtime_dtype: bool = True) SELF[source]#
Converts the module’s parameters to single-precision (float32).
Optionally also changes the runtime computation dtype (self.dtype) to float32.
- Parameters
change_runtime_dtype (bool) – If True, also sets self.dtype to jnp.float32. Defaults to True.
- Returns
The module instance with parameters (and potentially runtime dtype) set to float32.
- Return type
SELF
- property frequencies: Array#
Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.
Uses self.config.get_basic_frequencies() and caches the result.
- Returns
The frequency components, potentially cached.
- Return type
jnp.ndarray
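As an illustration of what such frequency components usually contain, standard rotary position embeddings (RoPE) use one inverse frequency theta^(-2i/d) per pair of head dimensions, and per-position angles obtained as an outer product with the position indices. The sketch assumes the unscaled formulation (no rope_scaling) and a hypothetical helper name; the exact layout returned by get_basic_frequencies() and get_basic_inv_frequencies() may differ.

```python
import jax.numpy as jnp


def rope_frequencies(head_dim: int, max_positions: int, theta: float = 10000.0):
    """Hypothetical helper returning (inv_freq, angles) for unscaled RoPE."""
    # One inverse frequency per pair of dimensions: theta ** (-2i / head_dim).
    exponents = jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = jnp.arange(max_positions, dtype=jnp.float32)
    angles = jnp.outer(positions, inv_freq)  # shape [max_positions, head_dim // 2]
    return inv_freq, angles


inv_freq, angles = rope_frequencies(head_dim=8, max_positions=16)
cos, sin = jnp.cos(angles), jnp.sin(angles)  # rotate query/key pairs at runtime
print(inv_freq.shape, angles.shape)  # (4,) (16, 4)
```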
- fully_gather() SELF[source]#
Applies JAX sharding constraints that gather every parameter into a fully replicated layout.
This function marks all parameters with an empty PartitionSpec() (no sharded axes), so each device holds a complete copy. It uses jax.jit with out_shardings to enforce these constraints.
- Returns
The model instance with gathering constraints applied.
- Return type
SELF
- fully_shard(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None) SELF[source]#
Applies JAX sharding constraints to all parameters based on the partition rules.
This function ensures that parameters are explicitly marked with their intended sharding, which can be useful for performance and correctness checks. It uses jax.jit with out_shardings to enforce the constraints.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses config rules. Defaults to None.
- Returns
The model instance with sharding constraints applied.
- Return type
SELF
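Both fully_shard and fully_gather rely on the out_shardings argument of jax.jit to pin the layout of their outputs. The sketch below shows that underlying JAX mechanism on a one-axis mesh: an empty PartitionSpec() requests a fully replicated array, while PartitionSpec("dp") splits the leading axis across the "dp" mesh axis. The axis name and the parameter dictionary are illustrative assumptions; the code also runs on a single CPU device, where both layouts coincide.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("dp",))

params = {"kernel": jnp.ones((len(devices) * 4, 8))}

replicated = NamedSharding(mesh, PartitionSpec())     # every device holds a full copy
row_split = NamedSharding(mesh, PartitionSpec("dp"))  # leading axis split over "dp"

# Identity functions whose outputs are constrained to the requested layouts.
gather = jax.jit(lambda tree: tree, out_shardings={"kernel": replicated})
shard = jax.jit(lambda tree: tree, out_shardings={"kernel": row_split})

sharded = shard(params)
gathered = gather(sharded)
print(sharded["kernel"].sharding)
print(gathered["kernel"].sharding)
```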
- gather_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None, overlay_fns: Optional[Mapping[str, Callable]] = None) SELF[source]#
Gathers the model’s parameters from potentially distributed devices to the host or a single device.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules used to determine how parameters were originally sharded. If None, uses config rules. Defaults to None.
mesh (tp.Optional[Mesh], optional) – JAX device mesh from which to gather. If None, uses config mesh. Defaults to None.
overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]], optional) – Additional functions to apply, potentially overriding default gathering for specific parameters. Defaults to None.
- Returns
The model instance with gathered parameters.
- Return type
SELF
- get_static_arguments() Tuple[source]#
Returns a tuple of static arguments required by the module’s __call__ method.
Static arguments are those that don’t change across calls and can potentially be cached or handled differently by JIT compilation. This base implementation returns an empty tuple; subclasses that have static arguments should override it.
- Returns
A tuple containing static arguments.
- Return type
tp.Tuple
- property graphdef: Union[NodeDef[Node], NodeRef[Node]]#
Returns the graph definition (structure without parameters) of the module.
Uses flax.nnx.split to separate the graph definition from the state (parameters).
- Returns
The graph definition of the module.
- Return type
nn.GraphDef
- property graphother: State[Key, VariableState[Any]]#
Returns any other state variables in the module (non-parameters).
Uses flax.nnx.split to separate non-parameter state variables.
- Returns
The graph state containing non-parameter variables.
- Return type
nn.GraphState
- property graphstate: State[Key, VariableState[Any]]#
Returns the graph state (parameters) of the module.
Uses flax.nnx.split to separate the state (parameters) from the graph definition.
- Returns
The graph state containing the module’s parameters.
- Return type
nn.GraphState
- property graphtree_params_shape: Dict#
Computes and returns the shapes of the module’s parameters as a nested dictionary.
It uses nnx.eval_shape to determine the shapes without actual computation, then extracts the shape information from the resulting graph state.
- Returns
A nested dictionary mirroring the parameter structure, containing their shapes.
- Return type
tp.Dict
- property graphtree_shape: Dict#
Computes and returns the shapes of all state variables (including non-parameters) in the module.
Uses nnx.eval_shape on the entire module state (parameters and others) and extracts the shape information.
- Returns
A nested dictionary mirroring the module’s state structure, containing the shapes.
- Return type
tp.Dict
- half(change_runtime_dtype: bool = True) SELF[source]#
Converts the module’s parameters to half-precision (float16).
Optionally also changes the runtime computation dtype (self.dtype) to float16.
- Parameters
change_runtime_dtype (bool) – If True, also sets self.dtype to jnp.float16. Defaults to True.
- Returns
The module instance with parameters (and potentially runtime dtype) set to float16.
- Return type
SELF
- property inv_frequencies: Array#
Retrieves or computes the inverse-frequency components (e.g., for RoPE) from the configuration.
Uses self.config.get_basic_inv_frequencies() and caches the result.
- Returns
The inverse-frequency components, potentially cached.
- Return type
jnp.ndarray
- classmethod lazy_init(*args, **kwargs) SELF[source]#
Performs a “lazy” initialization using nnx.eval_shape.
This initializes the module structure and determines parameter shapes without actually allocating memory for the parameters. Useful for inspecting the model structure or preparing for sharding.
- Parameters
*args – Positional arguments passed to the class constructor.
**kwargs – Keyword arguments passed to the class constructor.
- Returns
A module instance with initialized structure but potentially abstract parameters.
- Return type
SELF
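lazy_init is described as relying on nnx.eval_shape. The same idea can be shown with plain jax.eval_shape: the initializer is traced abstractly and returns jax.ShapeDtypeStruct leaves, so no parameter memory is allocated. `init_params` and its shapes are hypothetical stand-ins for a model constructor.

```python
import jax
import jax.numpy as jnp


def init_params(key, vocab_size: int = 32000, hidden_size: int = 4096):
    """Hypothetical initializer; a real model constructor builds many more arrays."""
    k_embed, k_head = jax.random.split(key)
    return {
        "embed": jax.random.normal(k_embed, (vocab_size, hidden_size), jnp.float32),
        "lm_head": jax.random.normal(k_head, (hidden_size, vocab_size), jnp.float32),
    }


# Only shapes and dtypes are computed; the ~1 GB of float32 weights is never allocated.
abstract = jax.eval_shape(init_params, jax.random.PRNGKey(0))
print(abstract["embed"])
```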
- property loss_function#
Determines and returns the appropriate loss function based on the configuration or model type.
It prioritizes config.loss_type, then self.loss_type, and finally tries to infer the loss type from the class name. If no suitable loss function is found, it defaults to ForCausalLMLoss and issues a warning.
- Returns
The selected loss function (e.g., ForCausalLMLoss, ForSequenceClassificationLoss).
- Return type
tp.Callable
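The resolution order described above (config.loss_type, then self.loss_type, then inference from the class name, then a ForCausalLMLoss fallback with a warning) can be sketched in plain Python. The registry contents, the suffix-matching rule, and all function names here are hypothetical; only the priority order comes from the description.

```python
import warnings
from typing import Callable, Dict, Optional


def for_causal_lm_loss(*args, **kwargs):  # placeholder loss functions
    return "causal_lm"


def for_sequence_classification_loss(*args, **kwargs):
    return "sequence_classification"


# Hypothetical registry mapping loss-type names to loss callables.
LOSS_REGISTRY: Dict[str, Callable] = {
    "ForCausalLM": for_causal_lm_loss,
    "ForSequenceClassification": for_sequence_classification_loss,
}


def resolve_loss_function(
    config_loss_type: Optional[str], module_loss_type: Optional[str], class_name: str
) -> Callable:
    # 1) config.loss_type, then 2) self.loss_type
    for candidate in (config_loss_type, module_loss_type):
        if candidate in LOSS_REGISTRY:
            return LOSS_REGISTRY[candidate]
    # 3) infer from the class name, e.g. "LlamaForCausalLM" -> "ForCausalLM"
    for loss_type, fn in LOSS_REGISTRY.items():
        if class_name.endswith(loss_type):
            return fn
    # 4) fall back to the causal LM loss and warn
    warnings.warn(f"No loss type found for {class_name}; using ForCausalLM.")
    return for_causal_lm_loss


print(resolve_loss_function(None, None, "MistralForSequenceClassification")())
# sequence_classification
```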
- merge_lora_params(pytree: Dict) SELF[source]#
Merges LoRA parameters from a pytree into the base model’s parameters.
- Parameters
pytree (tp.Dict) – A dictionary (pytree) containing the LoRA parameters (A and B matrices) structured similarly to the base model’s parameters.
- Returns
The module instance with LoRA parameters merged into the base weights.
- Return type
SELF
- merge_params(tree)[source]#
Merges a given parameter state tree back into the module.
Reconstructs the module using its existing graph definition and ‘other’ state, but replaces the parameter state with the provided tree.
- Parameters
tree – A pytree (typically an nn.GraphState) containing the parameters to merge.
- Returns
The module instance with the new parameters merged in.
- Return type
- merge_params_dict(params_dict: Dict) SELF[source]#
Merges parameters from a dictionary back into the module’s state.
Updates the module’s current parameter state with values from the provided dictionary.
- Parameters
params_dict (tp.Dict) – A nested dictionary containing the parameters to merge. The structure should match the module’s parameter structure.
- Returns
The module instance with the parameters from the dictionary merged in.
- Return type
SELF
- Raises
KeyError – If a key from params_dict is not found in the module’s current state.
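The merge semantics described above (existing entries are updated, unknown keys raise KeyError) can be sketched for plain nested dictionaries. `merge_nested` is a hypothetical helper that works on ordinary dicts, not on nnx state objects.

```python
from typing import Any, Dict


def merge_nested(current: Dict[str, Any], updates: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `current` with leaves replaced by `updates`.

    Raises KeyError if `updates` contains a key that does not exist in `current`.
    """
    merged = dict(current)
    for key, value in updates.items():
        if key not in current:
            raise KeyError(f"Unknown parameter key: {key!r}")
        if isinstance(value, dict) and isinstance(current[key], dict):
            merged[key] = merge_nested(current[key], value)
        else:
            merged[key] = value
    return merged


params = {"layer_0": {"kernel": 1.0, "bias": 0.0}}
print(merge_nested(params, {"layer_0": {"bias": 0.5}}))
# {'layer_0': {'kernel': 1.0, 'bias': 0.5}}
```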
- property mesh: Mesh#
Retrieves the JAX device mesh from the module’s configuration.
- Returns
The device mesh defined in self.config.mesh.
- Return type
Mesh
- property model_task: Optional[str]#
Returns the specific task associated with this model instance (e.g., ‘causal-language-model’).
- Returns
The model task identifier, or None if not set.
- Return type
tp.Optional[str]
- property model_type: Optional[str]#
Returns the specific type of this model instance (e.g., ‘llama’, ‘mistral’).
- Returns
The model type identifier, or None if not set.
- Return type
tp.Optional[str]
- property module_dtype: dtype#
Determines the data type of the module’s parameters.
It inspects the flattened parameter state to find the dtype of the first parameter encountered.
- Returns
The data type of the module’s parameters.
- Return type
jnp.dtype
- property parameters: Dict#
Retrieves the parameters of the module as a dictionary.
This property iterates through the module and its submodules, extracting variables marked as nn.Param and returning them in a flat dictionary where keys represent the parameter path.
- Returns
A dictionary containing the module’s parameters.
- Return type
tp.Dict
- property params: Dict#
Returns the parameters and other state variables of the module as a dictionary.
Uses flax.nnx.split to get the combined state (parameters and others).
- Returns
A dictionary containing all state variables of the module.
- Return type
tp.Dict
- property params_sharding: Dict#
Retrieves the sharding annotation for each parameter in the module.
- Returns
A nested dictionary mirroring the parameter structure, containing the sharding information (e.g., NamedSharding, PartitionSpec) for each parameter, or None if unsharded.
- Return type
tp.Dict
- prepare_inputs_for_call(**kwargs)[source]#
Prepares keyword arguments before passing them to the module’s __call__ method.
This base implementation simply returns the kwargs as is. Subclasses can override this to modify or add arguments as needed (e.g., for generation).
- Parameters
**kwargs – The keyword arguments intended for __call__.
- Returns
The prepared keyword arguments.
- Return type
dict
- property pure_transform_fn#
Returns a pure transformation function for PyTorch state dicts to EasyDeL parameters.
Similar to transform_fn, but this version does not include sharding functions. It identifies embedding and LayerNorm layers and returns a partial function (torch_dict_to_easydel_params) configured only with layer names and dtype.
- Returns
A partial function for converting a PyTorch state dict without applying sharding.
- Return type
tp.Callable
- quantize(method: EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.A8BIT, block_size: int = 128, quantization_pattern: Optional[str] = None, quantize_tensors: bool = True, verbose: Optional[bool] = None) SELF[source]#
Applies quantization to the module’s linear layers or tensors.
- Parameters
method (EasyDeLQuantizationMethods, optional) – The quantization algorithm to use (e.g., A8BIT, NF4). Defaults to EasyDeLQuantizationMethods.A8BIT.
block_size (int, optional) – The block size for quantization methods that support it. Defaults to 128.
quantization_pattern (tp.Optional[str], optional) – A regular expression to match parameter names that should be quantized. If None, uses a default pattern. Defaults to None.
quantize_tensors (bool, optional) – If True, quantizes the tensor values directly; if False, replaces Linear layers with their quantized equivalents. Defaults to True, although the current implementation may behave as if it were False.
verbose (tp.Optional[bool], optional) – If True, logs information during the quantization process. If None, defaults to True only on process index 0.
- Returns
The quantized model instance.
- Return type
SELF
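To make block_size concrete, here is a minimal blockwise 8-bit quantization sketch: the tensor is split into blocks of block_size values, each block is scaled by its absolute maximum, and values are rounded to int8. This symmetric absmax scheme is an illustrative assumption and is not necessarily the algorithm behind EasyDeLQuantizationMethods.A8BIT; both function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def quantize_blockwise_int8(x: jnp.ndarray, block_size: int = 128):
    """Symmetric absmax int8 quantization per block (x.size must be a multiple of block_size)."""
    blocks = x.reshape(-1, block_size)
    scale = jnp.max(jnp.abs(blocks), axis=-1, keepdims=True) / 127.0
    scale = jnp.where(scale == 0, 1.0, scale)  # avoid division by zero for all-zero blocks
    q = jnp.clip(jnp.round(blocks / scale), -127, 127).astype(jnp.int8)
    return q, scale


def dequantize_blockwise_int8(q: jnp.ndarray, scale: jnp.ndarray, shape) -> jnp.ndarray:
    return (q.astype(jnp.float32) * scale).reshape(shape)


w = jax.random.normal(jax.random.PRNGKey(0), (256, 64))
q, scale = quantize_blockwise_int8(w, block_size=128)
w_hat = dequantize_blockwise_int8(q, scale, w.shape)
print(q.dtype, scale.shape, float(jnp.max(jnp.abs(w - w_hat))))
```

Smaller blocks give each scale fewer values to cover, which typically lowers the rounding error at the cost of storing more scales.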
- shard_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None, overlay_fns: Optional[Mapping[str, Callable]] = None) SELF[source]#
Shards the model’s parameters according to the specified rules and mesh.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses config rules. Defaults to None.
mesh (tp.Optional[Mesh], optional) – JAX device mesh. If None, uses config mesh. Defaults to None.
overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]], optional) – Additional functions to apply, potentially overriding default sharding for specific parameters. Defaults to None.
- Returns
The sharded model instance.
- Return type
SELF
- split_lora_params() Dict[source]#
Splits merged LoRA parameters back out from the base model’s parameters.
This function assumes LoRA parameters were previously merged using merge_lora_params or a similar process that stored the original base weights and LoRA weights appropriately.
- Returns
A pytree containing the extracted LoRA parameters (A and B matrices). The base model parameters are restored to their original (pre-merge) state.
- Return type
tp.Dict
- split_params()[source]#
Splits the module and returns the parameter state.
Uses nnx.split to extract the GraphState containing the parameters.
- Returns
The parameter state of the module.
- Return type
nn.GraphState
- split_params_dict(extract_fn: Optional[Callable] = None, remove_none: bool = True) Dict[source]#
Splits the module parameters and returns them as a nested dictionary.
Extracts the parameter state, converts it to a plain dictionary (removing VariableState wrappers), and optionally removes entries with None values.
- Parameters
extract_fn (tp.Optional[tp.Callable], optional) – A function to apply to each parameter during extraction. Defaults to None.
remove_none (bool, optional) – If True, removes key-value pairs where the value is None. Defaults to True.
- Returns
A nested dictionary containing the module’s parameters.
- Return type
tp.Dict
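The extract_fn and remove_none options can be illustrated on a plain nested dictionary. `clean_params_dict` is a hypothetical sketch of that post-processing step, not EasyDeL's implementation.

```python
from typing import Any, Callable, Dict, Optional


def clean_params_dict(
    tree: Dict[str, Any],
    extract_fn: Optional[Callable] = None,
    remove_none: bool = True,
) -> Dict[str, Any]:
    out = {}
    for key, value in tree.items():
        if isinstance(value, dict):
            value = clean_params_dict(value, extract_fn, remove_none)
        elif value is not None and extract_fn is not None:
            value = extract_fn(value)  # applied to every non-None leaf
        if remove_none and value is None:
            continue  # drop entries whose value is None
        out[key] = value
    return out


raw = {"attn": {"kernel": 2.0, "bias": None}, "norm": {"scale": 1.0}}
print(clean_params_dict(raw, extract_fn=lambda v: v * 10))
# {'attn': {'kernel': 20.0}, 'norm': {'scale': 10.0}}
```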
- property static_arguments: Tuple#
Retrieves or computes static arguments needed for the module’s __call__ method.
Uses self.get_static_arguments() and caches the result. Static arguments are typically those that don’t change during execution and can be pre-computed.
- Returns
A tuple of static arguments.
- Return type
tp.Tuple
- to_dtype(dtype: dtype) SELF[source]#
Converts the module’s parameters to the specified data type.
It iterates through the module’s parameters (excluding quantization-related ones) and casts them to the target dtype. It also updates the param_dtype attribute of the module and its submodules if they exist.
- Parameters
dtype (jnp.dtype) – The target data type for the parameters.
- Returns
The module instance with parameters converted to the specified dtype.
- Return type
SELF
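A common way to perform this kind of conversion in JAX is a tree_map that casts only floating-point leaves and leaves integer arrays (for example, quantized weights) untouched. The sketch below assumes that general approach; EasyDeL's actual rule for excluding quantization-related parameters is not reproduced, and `cast_floating` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def cast_floating(tree, dtype):
    """Cast floating-point array leaves of a pytree to `dtype`; leave other leaves as-is."""
    def _cast(leaf):
        if isinstance(leaf, jax.Array) and jnp.issubdtype(leaf.dtype, jnp.floating):
            return leaf.astype(dtype)
        return leaf
    return jax.tree_util.tree_map(_cast, tree)


params = {"kernel": jnp.ones((2, 2), jnp.float32), "q_weight": jnp.ones((2, 2), jnp.int8)}
half_params = cast_floating(params, jnp.bfloat16)
print({name: leaf.dtype for name, leaf in half_params.items()})
```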
- to_state() Any[source]#
Converts the current module instance into an EasyDeLState object.
This is useful for saving and managing the model’s state, including parameters and potentially optimizer state (though optimizer state is typically added later).
- Returns
An EasyDeLState object representing the current model state.
- Return type
EasyDeLState
- to_torch(**kwargs)[source]#
Converts the EasyDeL module to its equivalent Hugging Face PyTorch model.
Requires the corresponding PyTorch model class to be available and registered. Uses utility functions to transfer parameters from JAX to PyTorch format.
- Parameters
**kwargs – Additional keyword arguments passed to the parameter transformation function.
- Returns
The equivalent Hugging Face PyTorch model with loaded weights.
- Return type
torch.nn.Module
- property transform_fn#
Returns a partial function for transforming PyTorch state dicts to EasyDeL parameters.
This function identifies embedding and LayerNorm layers within the module and creates a transformation function (torch_dict_to_easydel_params) pre-configured with these layer names, the target parameter dtype, and the module’s sharding functions.
- Returns
A partial function ready to convert a PyTorch state dict.
- Return type
tp.Callable
- unwrap_lora_to_layers(verbose: bool = False) SELF[source]#
Reverts the application of LoRA layers, restoring the original linear layers.
Replaces easydel.layers.lora.LoraLinear layers with their original flax.linen.Dense counterparts, discarding the LoRA matrices.
- Parameters
verbose (bool, optional) – If True, prints information about which layers are being reverted. Defaults to False.
- Returns
The module instance with LoRA layers removed and original layers restored.
- Return type
SELF
- class easydel.__init__.EasyDeLGradientCheckPointers(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#
Bases:
str, Enum
Enumeration of gradient checkpointing strategies available in EasyDeL.
Gradient checkpointing is a technique to reduce memory usage during training by recomputing activations during the backward pass instead of storing them.
- EVERYTHING_SAVEABLE#
Checkpoints residuals, attentions, and hidden states. This is the most memory-intensive checkpointing strategy.
- NOTHING_SAVEABLE#
Checkpoints only the residuals. This strategy saves the most memory but requires more recomputation.
- CHECKPOINT_DOTS#
Saves only the outputs of matrix multiplications (dot operations); other intermediate activations are recomputed during the backward pass.
- CHECKPOINT_DOTS_WITH_NO_BATCH_DMIS#
Similar to CHECKPOINT_DOTS, but saves only matrix-multiplication outputs that have no batch dimensions.
- NONE#
No gradient checkpointing is applied.
- CHECKPOINT_DOTS = 'checkpoint_dots'#
- CHECKPOINT_DOTS_WITH_NO_BATCH_DMIS = 'checkpoint_dots_with_no_batch_dims'#
- EVERYTHING_SAVEABLE = 'everything_saveable'#
- NONE = ''#
- NOTHING_SAVEABLE = 'nothing_saveable'#
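These string values match policy names in jax.checkpoint_policies, so one plausible way to consume them is a name lookup passed to jax.checkpoint (also known as jax.remat). The sketch below assumes that mapping; `maybe_remat` is hypothetical, and a local mirror of the enum is declared so the block runs on its own.

```python
import enum

import jax
import jax.numpy as jnp


class GradientCheckPointers(str, enum.Enum):  # local mirror of EasyDeLGradientCheckPointers
    EVERYTHING_SAVEABLE = "everything_saveable"
    NOTHING_SAVEABLE = "nothing_saveable"
    CHECKPOINT_DOTS = "checkpoint_dots"
    CHECKPOINT_DOTS_WITH_NO_BATCH_DMIS = "checkpoint_dots_with_no_batch_dims"
    NONE = ""


def maybe_remat(fn, strategy: GradientCheckPointers):
    """Wrap `fn` in jax.checkpoint with the policy named by `strategy` (hypothetical helper)."""
    if strategy == GradientCheckPointers.NONE:
        return fn
    policy = getattr(jax.checkpoint_policies, strategy.value)
    return jax.checkpoint(fn, policy=policy)


def mlp(x, w1, w2):
    return jnp.sum(jax.nn.silu(x @ w1) @ w2)


x, w1, w2 = jnp.ones((4, 8)), jnp.ones((8, 16)) * 0.1, jnp.ones((16, 2)) * 0.1
remat_mlp = maybe_remat(mlp, GradientCheckPointers.CHECKPOINT_DOTS)
grad_plain = jax.grad(mlp, argnums=1)(x, w1, w2)
grad_remat = jax.grad(remat_mlp, argnums=1)(x, w1, w2)
print(jnp.allclose(grad_plain, grad_remat))  # recomputation does not change the gradients
```

The policy only changes which intermediates are stored versus recomputed, trading memory for compute; the gradient values stay the same.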
- class easydel.__init__.EasyDeLOptimizers(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#
Bases:
str, Enum
Enumeration of available optimizers in the EasyDeL library.
- ADAFACTOR#
Represents the Adafactor optimizer.
- LION#
Represents the Lion optimizer.
- ADAMW#
Represents the AdamW optimizer.
- RMSPROP#
Represents the RMSprop optimizer.
- ADAFACTOR = 'adafactor'#
- ADAMW = 'adamw'#
- LION = 'lion'#
- RMSPROP = 'rmsprop'#
- class easydel.__init__.EasyDeLPlatforms(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#
Bases:
str, Enum
Enumeration of platforms or kernel execution backends supported by EasyDeL.
This allows selecting optimized kernel implementations for different hardware or software environments.
- JAX#
Use standard JAX kernel implementations.
- TRITON#
Use Triton-based kernel implementations (often for GPUs).
- PALLAS#
Use Pallas-based kernel implementations (often for TPUs).
- JAX = 'jax'#
- PALLAS = 'pallas'#
- TRITON = 'triton'#
- class easydel.__init__.EasyDeLQuantizationMethods(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#
Bases:
str, Enum
Enumeration of quantization methods supported by EasyDeL.
Quantization reduces the precision of model weights and/or activations to save memory and potentially speed up inference.
- NONE#
No quantization is applied.
- NF4#
Represents NormalFloat 4-bit quantization.
- A8BIT#
Represents 8-bit affine quantization.
- A8BIT = '8bit'#
- NF4 = 'nf4'#
- NONE = 'None'#
- class easydel.__init__.EasyDeLSchedulers(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#
Bases:
str, Enum
Enumeration of available learning rate schedulers in EasyDeL.
- NONE#
Indicates no scheduler should be used.
- LINEAR#
Represents a linear learning rate decay scheduler.
- COSINE#
Represents a cosine annealing learning rate scheduler.
- COSINE = 'cosine'#
- LINEAR = 'linear'#
- NONE = 'None'#
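For reference, LINEAR and COSINE correspond to the usual decay shapes: linear interpolation from an initial to a final learning rate, and cosine annealing between the same endpoints. The formulas below are a generic sketch with hypothetical parameter names (init_value, end_value, total_steps); EasyDeL's actual schedule construction, including any warmup handling, may differ.

```python
import math


def linear_schedule(step: int, init_value: float, end_value: float, total_steps: int) -> float:
    frac = min(step, total_steps) / total_steps
    return init_value + frac * (end_value - init_value)


def cosine_schedule(step: int, init_value: float, end_value: float, total_steps: int) -> float:
    frac = min(step, total_steps) / total_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * frac))  # decreases from 1 to 0
    return end_value + (init_value - end_value) * cosine


for step in (0, 250, 500, 1000):
    print(step, linear_schedule(step, 1e-3, 0.0, 1000), cosine_schedule(step, 1e-3, 0.0, 1000))
```

Both schedules meet at the midpoint, but the cosine curve stays near the initial rate longer early on and flattens out near the end.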
- class easydel.__init__.EasyDeLState(step: int | jax.Array, graphdef: nn.GraphDef, graphstate: nn.GraphState, graphother: nn.GraphState, tx: optax.GradientTransformation, opt_state: tp.Optional[optax.OptState], apply_fn: tp.Optional[tp.Callable] = None)[source]#
Bases:
PyTreeNode
Represents the state of an EasyDeL model during training or inference.
This class encapsulates the model’s parameters, optimizer state, training step, and potentially other metadata. It provides methods for applying gradients, managing sharding, saving, and loading the state.
- graphdef#
The definition of the model’s computation graph (structure).
- Type
nn.GraphDef
- graphstate#
The state of the model’s parameters.
- Type
nn.GraphState
- graphother#
The state of non-parameter variables within the model.
- Type
nn.GraphState
- tx#
The optimizer transformation (e.g., AdamW, SGD). Marked as a non-pytree node.
- Type
optax.GradientTransformation
- opt_state#
The state of the optimizer (e.g., moments). Marked as a pytree node.
- Type
tp.Optional[optax.OptState]
- apply_fn#
A function to apply the model (often model.__call__). Typically not directly part of the state but can be associated.
- Type
tp.Optional[tp.Callable]
- apply_fn: tp.Optional[tp.Callable] = None#
- apply_gradients(*, grads)[source]#
Updates the model’s parameters and optimizer state based on calculated gradients.
- Parameters
grads – A pytree matching the structure of self.graphstate containing the gradients.
- Returns
A new state object with the updated parameters (graphstate), optimizer state (opt_state), and incremented step count.
- Return type
EasyDeLState
- Raises
AssertionError – If opt_state or tx is not initialized.
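Conceptually, apply_gradients performs one optimizer update and returns a new, immutable state with the step incremented. The sketch below uses a frozen dataclass and a hand-written SGD step in place of tx.update and optax.apply_updates, purely to show that functional pattern; `TinyState` and `sgd_apply_gradients` are hypothetical.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class TinyState:
    step: int
    params: dict
    learning_rate: float = 0.1


def sgd_apply_gradients(state: TinyState, grads: dict) -> TinyState:
    """Return a new state; `grads` must match the structure of `state.params`."""
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - state.learning_rate * g, state.params, grads
    )
    return dataclasses.replace(state, step=state.step + 1, params=new_params)


state = TinyState(step=0, params={"w": jnp.array([1.0, 2.0])})
grads = jax.grad(lambda p: jnp.sum(p["w"] ** 2))(state.params)  # gradient is 2 * w
new_state = sgd_apply_gradients(state, grads)
print(new_state.step, new_state.params["w"])  # step 1, w = [0.8, 1.6]
```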
- classmethod create(*, step: tp.Optional[int] = None, graphdef: tp.Optional[nn.GraphDef] = None, graphstate: tp.Optional[nn.GraphState] = None, graphother: tp.Optional[nn.GraphState] = None, model: tp.Optional[nn.Module] = None, tx: tp.Optional[optax.GradientTransformation] = None, opt_state: tp.Optional[optax.OptState] = None, init_opt_state: bool = False) EasyDeLState[source]#
Creates a new EasyDeLState instance.
This class method provides a flexible way to initialize the state, either from an existing nn.Module or by providing the graph components (graphdef, graphstate, graphother) directly. It also handles optimizer state initialization.
- Parameters
step (tp.Optional[int]) – The initial training step. If None, defaults to 0.
graphdef (tp.Optional[nn.GraphDef]) – The model’s graph definition.
graphstate (tp.Optional[nn.GraphState]) – The model’s parameter state.
graphother (tp.Optional[nn.GraphState]) – The model’s non-parameter state.
model (tp.Optional[nn.Module]) – An EasyDeL module instance. If provided, graphdef, graphstate, and graphother are derived from it. Cannot be provided simultaneously with graph components.
tx (tp.Optional[optax.GradientTransformation]) – The optimizer transformation.
opt_state (tp.Optional[optax.OptState]) – The initial optimizer state. Cannot be provided if init_opt_state is True.
init_opt_state (bool) – If True, initializes the optimizer state using tx.init(graphstate). Requires tx to be provided. Defaults to False.
- Returns
A new instance of the state.
- Return type
EasyDeLState
- Raises
ValueError – If model and graph components are provided simultaneously.
ValueError – If graph components are provided partially.
ValueError – If init_opt_state is True and opt_state is also provided.
ValueError – If init_opt_state is True but tx is not provided.
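The argument rules listed under Raises can be summarised as a small validation routine. `validate_create_args` is a hypothetical sketch that mirrors only the documented error conditions; it does not build a state.

```python
from typing import Any, Optional


def validate_create_args(
    model: Optional[Any] = None,
    graphdef: Optional[Any] = None,
    graphstate: Optional[Any] = None,
    graphother: Optional[Any] = None,
    tx: Optional[Any] = None,
    opt_state: Optional[Any] = None,
    init_opt_state: bool = False,
) -> None:
    provided = [part is not None for part in (graphdef, graphstate, graphother)]
    if model is not None and any(provided):
        raise ValueError("Pass either `model` or graph components, not both.")
    if any(provided) and not all(provided):
        raise ValueError("graphdef, graphstate and graphother must be passed together.")
    if init_opt_state and opt_state is not None:
        raise ValueError("Cannot pass `opt_state` when `init_opt_state=True`.")
    if init_opt_state and tx is None:
        raise ValueError("`init_opt_state=True` requires `tx`.")


validate_create_args(model=object(), tx=object(), init_opt_state=True)  # valid combination
```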
- gather_model(partition_rules: Any = None, mesh: Optional[Any] = None) EasyDeLState[source]#
Gathers the model parameters (graphstate and graphother) from distributed devices.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules used for the original sharding. If None, uses model config rules. Defaults to None.
mesh (tp.Optional[Mesh], optional) – The JAX device mesh to gather from. If None, uses model’s mesh. Defaults to None.
- Returns
A new state object with gathered graphstate and graphother.
- Return type
EasyDeLState
- gather_optimizer_state(partition_rules=None)[source]#
Gathers the optimizer state from potentially distributed devices.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules used to determine how the state was sharded. If None, uses rules from the model’s config. Defaults to None.
- Returns
A new state object with the gathered opt_state.
- Return type
EasyDeLState
- Raises
AssertionError – If opt_state is not initialized.
- gather_state()[source]#
Gathers the entire state (model parameters and optimizer state) from distributed devices.
This is a convenience method that calls gather_model and gather_optimizer_state.
- Returns
A new state object with both model and optimizer states gathered.
- Return type
EasyDeLState
- graphdef: nn.GraphDef#
- graphother: nn.GraphState#
- graphstate: nn.GraphState#
- init_tx(tx: GradientTransformation, partition_rules: Any = None) EasyDeLState[source]#
Initializes the optimizer state (opt_state) for the current graphstate using the provided optimizer transformation (tx). It automatically handles sharding based on the model’s partition rules.
- Parameters
tx (optax.GradientTransformation) – The optimizer transformation to initialize with.
partition_rules (PartitionLike, optional) – Partitioning rules for the optimizer state. If None, uses the rules from the associated model’s config. Defaults to None.
- Returns
A new state object with the initialized and potentially sharded opt_state and the provided tx.
- Return type
EasyDeLState
- load_optimizer(load_directory: Union[str, PathLike])[source]#
Loads the optimizer state from saved files.
Reads the optimizer state structure from a pickle file (OPTIMIZER_STRUCT_NAME) and the tensor data from a SafeTensors file (OPTIMIZER_NAME) within the specified directory.
- Parameters
load_directory (tp.Union[str, os.PathLike]) – The directory containing the saved optimizer state files.
- Returns
A new state object with the loaded opt_state.
- Return type
EasyDeLState
- Raises
FileNotFoundError – If the required optimizer files are not found.
Exception – If any error occurs during loading or deserialization.
- merge(tree) Any[source]#
Merges a given state tree (usually parameters) with the graph definition and other state components to reconstruct the full model module.
- Parameters
tree – The pytree (e.g., nn.GraphState) containing the parameters to merge.
- Returns
The reconstructed model module.
- Return type
- merge_to_state(tree) EasyDeLState[source]#
Creates a new EasyDeLState by replacing the current graphstate with the provided tree.
- Parameters
tree – The pytree (e.g., nn.GraphState) containing the new parameters.
- Returns
A new state object with the updated graphstate.
- Return type
EasyDeLState
- property model: Any#
Reconstructs and returns the full EasyDeL model module from the state components.
- Returns
The model module instance.
- Return type
- opt_state: tp.Optional[optax.OptState]#
- replace(**updates)#
Returns a new object replacing the specified fields with new values.
- save_state(save_directory: Union[str, PathLike], float_dtype: Optional[dtype] = None, verbose: bool = True, mismatch_allowed: bool = True, save_optimizer: bool = True, enable: Optional[bool] = None)[source]#
Saves the entire EasyDeLState to a directory.
This includes saving the model parameters (using model.save_pretrained) and optionally the optimizer state.
- Parameters
save_directory (tp.Union[str, os.PathLike]) – The directory to save the state to.
float_dtype (tp.Optional[jax.numpy.dtype]) – Optional dtype to cast floating-point parameters to before saving. Defaults to None.
verbose (bool) – If True, logs information during saving. Defaults to True.
mismatch_allowed (bool) – Passed to model.save_pretrained, allows saving even if the model structure differs slightly from expected. Defaults to True.
save_optimizer (bool) – If True, saves the optimizer state. Defaults to True.
enable (tp.Optional[bool]) – If set, controls whether saving happens (True) or is skipped (False). If None, saving typically occurs only on JAX process index 0. Defaults to None.
- shard_model(partition_rules: Any = None, mesh: Optional[Any] = None) EasyDeLState[source]#
Shards the model parameters (graphstate and graphother) based on partition rules.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses model config rules. Defaults to None.
mesh (tp.Optional[Mesh], optional) – The JAX device mesh to shard across. If None, uses model’s mesh. Defaults to None.
- Returns
A new state object with sharded graphstate and graphother.
- Return type
EasyDeLState
- shard_optimizer_state(opt_state: Optional[Any] = None, partition_rules: Any = None) Any[source]#
Applies sharding to the optimizer state based on partition rules.
- Parameters
opt_state (tp.Optional[tp.Any]) – The optimizer state pytree to shard. If None, uses self.opt_state. Defaults to None.
partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses rules from the model’s config. Defaults to None.
- Returns
A new state object with the sharded opt_state.
- Return type
EasyDeLState
- Raises
ValueError – If optimizer state is not initialized (neither opt_state argument nor self.opt_state is available).
- shard_state(partition_rules: Any = None) EasyDeLState[source]#
Shards the entire state (model parameters and optimizer state) based on partition rules.
This is a convenience method that calls shard_model and shard_optimizer_state.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules. If None, uses rules from the model’s config. Defaults to None.
- Returns
A new state object with both model and optimizer states sharded.
- Return type
EasyDeLState
- shard_with_shape(shape) EasyDeLState[source]#
Applies sharding constraints to the entire state based on a reference shape pytree.
This method takes a pytree shape which has the same structure as the EasyDeLState but contains sharding annotations (e.g., NamedSharding) instead of actual array data. It applies these shardings as constraints to the corresponding arrays in the current state.
- Parameters
shape – A pytree with the same structure as self, containing sharding annotations.
- Returns
A new state object with sharding constraints applied.
- Return type
EasyDeLState
- property shardings#
Retrieves the sharding annotations (e.g., NamedSharding) for all components of the EasyDeLState pytree.
- Returns
A pytree with the same structure as self, containing sharding annotations or None for components without sharding.
- property size: int#
Calculates the total size in bytes of the model parameters (graphstate) and the optimizer state (opt_state).
- Returns
The total size in bytes.
- Return type
int
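The size computation can be approximated by summing nbytes over the array leaves of both pytrees. The dictionaries below are hypothetical stand-ins for graphstate and opt_state, and `tree_nbytes` is an illustrative helper.

```python
import jax
import jax.numpy as jnp


def tree_nbytes(*trees) -> int:
    """Total bytes of all array leaves across the given pytrees (None contributes nothing)."""
    return sum(
        leaf.nbytes
        for tree in trees
        for leaf in jax.tree_util.tree_leaves(tree)
        if hasattr(leaf, "nbytes")
    )


graphstate = {"kernel": jnp.zeros((1024, 1024), jnp.bfloat16)}  # 2 MiB
opt_state = {
    "mu": jnp.zeros((1024, 1024), jnp.float32),  # 4 MiB
    "nu": jnp.zeros((1024, 1024), jnp.float32),  # 4 MiB
}
print(tree_nbytes(graphstate, opt_state))  # 10485760 (10 MiB)
```

With an Adam-style optimizer, the two float32 moment buffers alone can be several times larger than bfloat16 parameters, which is why including opt_state matters for the total.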
- tx: optax.GradientTransformation#
- class easydel.__init__.ExaoneConfig(vocab_size: int = 102400, hidden_size: int = 2048, intermediate_size: int = 14336, num_layers: int = 32, num_attention_heads: int = 32, num_key_value_heads: int = 8, activation_function='silu', max_position_embeddings=2048, initializer_range=0.02, layer_norm_epsilon=1e-05, use_cache=True, embed_dropout: float = 0.0, pad_token_id: Optional[int] = None, bos_token_id: int = 1, eos_token_id: int = 2, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling: Dict[str, Union[str, float]] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, attention_dropout: float = 0.0, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 102400) – Vocabulary size of the Exaone model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.
hidden_size (int, optional, defaults to 2048) – Dimensionality of the encoder layers and the pooler layer.
intermediate_size (int, optional, defaults to 14336) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
head_dim (int, defaults to 128) – Dimensionality of the head for attention.
num_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (int, optional, defaults to 8) – Number of key and value heads for each attention layer in the Transformer encoder.
activation_function (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.
max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_epsilon (float, optional, defaults to 1e-5) – The epsilon used by the RMS normalization layers.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
pad_token_id (int, optional) – The index of the padding token in the vocabulary.
bos_token_id (int, optional, defaults to 1) – The id of the beginning-of-sequence token.
eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.
tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.
rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.
rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
use_scan_mlp (bool, optional, defaults to False) – Whether to use the scan implementation for the MLP.
scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size to use when scanning the MLP.
bits (int, optional) – The number of bits to quantize the model to.
attention_bias (bool, optional, defaults to False) – Whether to use bias in the attention layer.
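With the defaults above (num_attention_heads=32, num_key_value_heads=8), attention is grouped-query: several query heads share one key/value head. The following is a minimal sketch of the usual grouping arithmetic, assuming the common convention that query head i reads KV head i // group_size. The helper name is hypothetical and not an EasyDeL API.

```python
def kv_head_for_query_head(query_head: int, num_attention_heads: int, num_key_value_heads: int) -> int:
    """Return the index of the key/value head that a query head reads from.

    Hypothetical helper illustrating grouped-query attention (GQA).
    """
    if num_attention_heads % num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    if not 0 <= query_head < num_attention_heads:
        raise ValueError("query_head out of range")
    group_size = num_attention_heads // num_key_value_heads  # query heads per KV head
    return query_head // group_size


# Exaone defaults: 32 query heads share 8 KV heads, i.e. 4 query heads per KV head.
mapping = [kv_head_for_query_head(h, 32, 8) for h in range(32)]
print(mapping[:8])  # [0, 0, 0, 0, 1, 1, 1, 1]
```

Setting num_key_value_heads equal to num_attention_heads recovers ordinary multi-head attention (each query head gets its own KV head).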
- attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, attention_dropout: float = 0.0, rope_scaling: Dict[str, Union[str, float]] = None, attention_bias: bool = False, **kwargs)[source]#
The attach_custom_arguments method adds the following custom arguments to the configuration:
- Parameters
gradient_checkpointing (EasyDeLGradientCheckPointers) – Gradient checkpointing strategy.
use_scan_mlp (bool) – Whether to use the scan implementation for the MLP.
scan_mlp_chunk_size (int) – Chunk size to use when scanning the MLP.
bits (tp.Optional[int]) – Number of bits to use for quantization.
attention_dropout (float) – Dropout rate for the attention probabilities.
attention_bias (bool) – Whether to use bias in the attention layer.
rope_scaling (tp.Dict[str, tp.Union[str, float]]) – Configuration for RoPE scaling.
- Return type
A tuple of the following
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
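The partition rules are a sequence of (regex, PartitionSpec) pairs. The following is a minimal sketch of how such rules are typically resolved against flattened parameter paths, assuming first-match-wins with re.search. Plain tuples stand in for jax.sharding.PartitionSpec so the sketch runs without JAX, and the patterns and paths are hypothetical, not the actual Exaone rules.

```python
import re
from typing import Optional, Sequence, Tuple

# Stand-in for jax.sharding.PartitionSpec: one mesh-axis name (or None) per array dimension.
Spec = Tuple[Optional[str], ...]
Rule = Tuple[str, Spec]

# Hypothetical rules for illustration; the real ones come from get_partition_rules().
RULES: Sequence[Rule] = (
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"(q_proj|k_proj|v_proj)/kernel", ("fsdp", "tp")),
    (r"out_proj/kernel", ("tp", "fsdp")),
    (r".*", ()),  # catch-all: empty spec, i.e. fully replicated
)


def match_partition_spec(param_path: str, rules: Sequence[Rule] = RULES) -> Spec:
    """Return the spec of the first rule whose regex is found in param_path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


print(match_partition_spec("model/layers/0/attn/q_proj/kernel"))  # ('fsdp', 'tp')
print(match_partition_spec("model/norm/scale"))  # ()
```

Because the first match wins, specific patterns must precede broad ones, and a trailing catch-all keeps unmatched parameters from raising.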
- static get_weight_decay_exclusions()[source]#
Returns a tuple of parameter names for which weight decay should be excluded.
- Returns
An empty tuple, indicating no weight decay exclusions.
- Return type
tuple
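An exclusion tuple like this is usually turned into a boolean mask that tells the optimizer which parameters receive weight decay (optax.adamw and optax.add_decayed_weights both accept a mask argument). The following is a minimal plain-Python sketch, assuming exclusions are matched as substrings of flattened parameter names. The function name and the parameter names are hypothetical.

```python
from typing import Dict, Iterable


def weight_decay_mask(param_names: Iterable[str], exclusions: tuple) -> Dict[str, bool]:
    """Map each parameter name to True (apply weight decay) or False (excluded)."""
    return {name: not any(excl in name for excl in exclusions) for name in param_names}


names = ["layers/0/mlp/kernel", "layers/0/norm/scale", "lm_head/kernel"]

# get_weight_decay_exclusions() returns an empty tuple here, so nothing is excluded.
print(weight_decay_mask(names, ()))
# A hypothetical exclusion list that skips normalization scales:
print(weight_decay_mask(names, ("norm",)))
```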
- property granted_freq_max_position_embedding: int#
Returns the maximum position embedding size for frequency-based position embeddings.
- Returns
The maximum position embedding size, falling back to max_position_embeddings if not explicitly set.
- Return type
int
- property granted_mask_max_position_embedding: int#
Returns the maximum position embedding size for mask-based position embeddings.
- Returns
The maximum position embedding size, falling back to max_position_embeddings if not explicitly set.
- Return type
int
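Both properties fall back to max_position_embeddings when no dedicated value is set. The following is a minimal sketch of that fallback pattern, assuming the override is stored as an optional attribute. The class and the attribute names freq_max_position_embeddings and mask_max_position_embeddings are hypothetical stand-ins, not confirmed EasyDeL fields.

```python
from typing import Optional


class PositionLimits:
    """Hypothetical stand-in illustrating the granted_* fallback behavior."""

    def __init__(
        self,
        max_position_embeddings: int,
        freq_max_position_embeddings: Optional[int] = None,
        mask_max_position_embeddings: Optional[int] = None,
    ):
        self.max_position_embeddings = max_position_embeddings
        self.freq_max_position_embeddings = freq_max_position_embeddings
        self.mask_max_position_embeddings = mask_max_position_embeddings

    @property
    def granted_freq_max_position_embedding(self) -> int:
        # Use the dedicated value if set, otherwise fall back to max_position_embeddings.
        if self.freq_max_position_embeddings is not None:
            return self.freq_max_position_embeddings
        return self.max_position_embeddings

    @property
    def granted_mask_max_position_embedding(self) -> int:
        if self.mask_max_position_embeddings is not None:
            return self.mask_max_position_embeddings
        return self.max_position_embeddings


limits = PositionLimits(max_position_embeddings=2048, freq_max_position_embeddings=8192)
print(limits.granted_freq_max_position_embedding)  # 8192
print(limits.granted_mask_max_position_embedding)  # 2048
```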
- model_type: str = 'exaone'#
- class easydel.__init__.ExaoneForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Exaone model with a language modeling head for causal language modeling tasks.
This model extends the base ExaoneModel by adding a linear language modeling head on top of the transformer model. It’s designed for generative tasks and can be used for text generation.
- class easydel.__init__.ExaoneForSequenceClassification(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.ExaoneModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- property frequencies#
Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.
Uses self.config.get_basic_frequencies() and caches the result.
- Returns
The frequency components, potentially cached.
- Return type
jnp.ndarray
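For RoPE, the frequency components are derived from rope_theta and the head dimension. The following is a minimal pure-Python sketch of the standard (unscaled) inverse-frequency formula inv_freq[i] = theta^(-2i/d), cached with functools.cached_property in the same spirit as the caching described above. It ignores rope_scaling and does not reproduce get_basic_frequencies(), whose signature and output layout are not shown here. The class name is hypothetical.

```python
from functools import cached_property
from typing import List


class RopeFrequencies:
    """Hypothetical helper computing and caching standard RoPE inverse frequencies."""

    def __init__(self, head_dim: int, rope_theta: float = 10000.0):
        if head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        self.head_dim = head_dim
        self.rope_theta = rope_theta

    @cached_property
    def inv_freq(self) -> List[float]:
        # One frequency per pair of dimensions: theta^(-2i/d) for i in [0, d/2).
        return [self.rope_theta ** (-(2 * i) / self.head_dim) for i in range(self.head_dim // 2)]

    def angles(self, position: int) -> List[float]:
        # Rotation angle for each dimension pair at a given token position.
        return [position * f for f in self.inv_freq]


rope = RopeFrequencies(head_dim=128, rope_theta=10000.0)
print(len(rope.inv_freq))  # 64
print(rope.inv_freq[0])  # 1.0
print(rope.inv_freq is rope.inv_freq)  # True: computed once, then cached
```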
- class easydel.__init__.FalconConfig(vocab_size=65024, hidden_size=4544, num_hidden_layers=32, num_attention_heads=71, num_ln_in_parallel_attn=None, layer_norm_epsilon=1e-05, initializer_range=0.02, use_cache=True, hidden_dropout=0.0, attention_dropout=0.0, num_kv_heads=None, alibi=False, new_decoder_architecture=False, multi_query=True, parallel_attn=True, bias=False, max_position_embeddings=2048, rope_theta=10000.0, rope_scaling=None, bos_token_id=11, eos_token_id=11, ffn_hidden_size=None, ff_factor=None, activation='gelu', gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 65024) – Vocabulary size of the Falcon model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.
hidden_size (int, optional, defaults to 4544) – Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 71) – Number of attention heads for each attention layer in the Transformer encoder.
num_ln_in_parallel_attn (int, optional) – The number of layer norms in the parallel attention layer.
layer_norm_epsilon (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
hidden_dropout (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
num_kv_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer encoder. Will default to num_attention_heads if not set.
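alibi (bool, optional, defaults to False) – Whether to use ALiBi (attention with linear biases) positional encoding instead of rotary position embeddings.
new_decoder_architecture (bool, optional, defaults to False) – Whether to use the new decoder architecture.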
multi_query (bool, optional, defaults to True) – Whether to use multi-query attention.
parallel_attn (bool, optional, defaults to True) – Whether to use parallel attention.
bias (bool, optional, defaults to False) – Whether to use bias in the linear layers.
max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).
rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.
rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The rope scaling configuration.
bos_token_id (int, optional, defaults to 11) – The index of the beginning of sequence token in the vocabulary.
eos_token_id (int, optional, defaults to 11) – The index of the end of sequence token in the vocabulary.
ffn_hidden_size (int, optional) – Dimensionality of the hidden layer in the FFN.
ff_factor (int, optional) – The scaling factor of the FFN.
activation (str, optional, defaults to “gelu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.
bits (int, optional) – The number of bits to quantize the model to.
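Falcon's attention layout depends on three interacting flags. The following is a minimal sketch of how the effective number of key/value heads is commonly resolved, assuming the same precedence as the Hugging Face Falcon implementation: num_kv_heads defaults to num_attention_heads, and it is used when new_decoder_architecture is set or multi_query is off; otherwise all query heads share a single KV head. The function is hypothetical, and EasyDeL's FalconModel may differ in detail.

```python
from typing import Optional


def effective_kv_heads(
    num_attention_heads: int,
    num_kv_heads: Optional[int] = None,
    multi_query: bool = True,
    new_decoder_architecture: bool = False,
) -> int:
    """Return how many key/value heads a Falcon-style attention layer uses (hypothetical helper)."""
    # num_kv_heads defaults to num_attention_heads when unset.
    configured = num_attention_heads if num_kv_heads is None else num_kv_heads
    if new_decoder_architecture or not multi_query:
        return configured
    # Classic multi-query attention: every query head shares a single KV head.
    return 1


print(effective_kv_heads(71))  # 1 (defaults: multi_query=True)
print(effective_kv_heads(71, multi_query=False))  # 71
print(effective_kv_heads(128, num_kv_heads=8, new_decoder_architecture=True))  # 8
```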
- attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#
Attach custom arguments to the configuration.
- Parameters
gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing strategy. Defaults to EasyDeLGradientCheckPointers.NONE.
bits (int, optional) – Quantization bits. Defaults to None.
**kwargs – Additional keyword arguments.
- Returns
The updated configuration instance.
- Return type
- attribute_map: dict[str, str] = {'num_attention_heads': 'num_attention_heads', 'num_hidden_layers': 'num_hidden_layers'}#
- static get_mesh_names()[source]#
Returns the mesh names used for model parallelism.
- Returns
A tuple containing “dp”, “fsdp”, and “tp” as the mesh names.
- Return type
tuple
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- property granted_freq_max_position_embedding: int#
Returns the maximum position embedding size for frequency-based position embeddings.
- Returns
The maximum position embedding size, falling back to max_position_embeddings if not explicitly set.
- Return type
int
- property granted_mask_max_position_embedding: int#
Returns the maximum position embedding size for mask-based position embeddings.
- Returns
The maximum position embedding size, falling back to max_position_embeddings if not explicitly set.
- Return type
int
- model_type: str = 'falcon'#
- property rotary#
- class easydel.__init__.FalconForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Falcon model with a language modeling head for causal language modeling tasks.
This model extends the base FalconModel by adding a linear language modeling head on top of the base model, designed for generative tasks and text generation. The model uses either ALiBi positional biases or rotary position embeddings (RoPE), selected by the alibi configuration flag.
- class easydel.__init__.FalconModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.GPT2Config(vocab_size=50257, n_positions=1024, n_embd=768, n_layer=12, n_head=12, n_inner=None, activation_function='gelu_new', resid_pdrop=0.1, embd_pdrop=0.1, attn_pdrop=0.1, layer_norm_epsilon=1e-05, initializer_range=0.02, summary_type='cls_index', summary_use_proj=True, summary_activation=None, summary_proj_to_labels=True, summary_first_dropout=0.1, scale_attn_weights=True, use_cache=True, bos_token_id=50256, eos_token_id=50256, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, tie_word_embeddings: bool = False, bits: Optional[int] = None, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 50257) – Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.
n_positions (int, optional, defaults to 1024) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).
n_embd (int, optional, defaults to 768) – Dimensionality of the encoder layers and the pooler layer.
n_layer (int, optional, defaults to 12) – Number of hidden layers in the Transformer encoder.
n_head (int, optional, defaults to 12) – Number of attention heads for each attention layer in the Transformer encoder.
n_inner (int, optional) – Dimensionality of the inner feed-forward layers.
activation_function (str, optional, defaults to “gelu_new”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.
resid_pdrop (float, optional, defaults to 0.1) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
embd_pdrop (float, optional, defaults to 0.1) – The dropout ratio for the embeddings.
attn_pdrop (float, optional, defaults to 0.1) – The dropout ratio for the attention probabilities.
layer_norm_epsilon (float, optional, defaults to 1e-5) – The epsilon to use in the layer normalization layers.
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
summary_type (str, optional, defaults to “cls_index”) – The summary type to use. Possible values are “cls_index” (equivalent to the output of the last token of the first sentence in a sequence) and “last” (equivalent to the output of the last token in the sequence).
summary_use_proj (bool, optional, defaults to True) – Whether to use a projection after the vector extraction.
summary_activation (str, optional) – The activation to use for the summary.
summary_proj_to_labels (bool, optional, defaults to True) – Whether to project the summary to the labels.
summary_first_dropout (float, optional, defaults to 0.1) – The dropout ratio to be used after the projection and activation.
scale_attn_weights (bool, optional, defaults to True) – Scale attention weights by dividing by sqrt(hidden_size).
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
bos_token_id (int, optional, defaults to 50256) – The id of the beginning-of-sequence token.
eos_token_id (int, optional, defaults to 50256) – The id of the end-of-sequence token.
scale_attn_by_inverse_layer_idx (bool, optional, defaults to False) – Whether to additionally scale attention weights by 1 / (layer_idx + 1).
reorder_and_upcast_attn (bool, optional, defaults to False) – Whether to reorder and upcast attention.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.
tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.
bits (int, optional) – The number of bits to quantize the model to.
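When scale_attn_by_inverse_layer_idx is enabled, attention scores in the layer with zero-based index layer_idx are additionally divided by (layer_idx + 1), on top of the square-root scaling controlled by scale_attn_weights. The following is a minimal sketch of the combined multiplier. It assumes the dimension used for the square-root scaling is passed in explicitly as scale_dim, and the function name is hypothetical.

```python
import math


def attention_score_multiplier(
    layer_idx: int,
    scale_dim: int,
    scale_attn_weights: bool = True,
    scale_attn_by_inverse_layer_idx: bool = False,
) -> float:
    """Multiplier applied to raw attention scores (q . k) before the softmax."""
    multiplier = 1.0
    if scale_attn_weights:
        # Divide by the square root of the chosen scaling dimension.
        multiplier /= math.sqrt(scale_dim)
    if scale_attn_by_inverse_layer_idx:
        # layer_idx is zero-based: the first layer is divided by 1, the second by 2, ...
        multiplier /= float(layer_idx + 1)
    return multiplier


print(attention_score_multiplier(0, 64, scale_attn_by_inverse_layer_idx=True))  # 0.125
print(attention_score_multiplier(3, 64, scale_attn_by_inverse_layer_idx=True))  # 0.03125
```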
- attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#
Attaches custom arguments to the configuration object.
This method allows adding or overriding configuration attributes dynamically. It iterates through the provided arguments and sets them as attributes of the configuration object if they don’t already exist.
- Parameters
gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing strategy. Defaults to EasyDeLGradientCheckPointers.NONE.
bits (tp.Optional[int], optional) – Quantization bits. Defaults to None.
**kwargs – Additional keyword arguments to attach to the configuration.
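The set-if-missing behavior described above can be sketched as follows, assuming a plain Python object and hasattr as the existence test. The function and the DemoConfig class are hypothetical illustrations, not EasyDeL's implementation.

```python
from typing import Any


def attach_if_missing(obj: Any, **kwargs: Any) -> Any:
    """Set each keyword as an attribute on obj unless obj already has it; return obj."""
    for name, value in kwargs.items():
        if not hasattr(obj, name):
            setattr(obj, name, value)
    return obj


class DemoConfig:  # hypothetical stand-in for a config object
    def __init__(self):
        self.bits = 8


cfg = attach_if_missing(DemoConfig(), bits=None, gradient_checkpointing="none")
print(cfg.bits)  # 8: the existing attribute is kept
print(cfg.gradient_checkpointing)  # none: the missing attribute is added
```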
- attribute_map: dict[str, str] = {'hidden_size': 'n_embd', 'max_position_embeddings': 'n_positions', 'num_attention_heads': 'n_head', 'num_hidden_layers': 'n_layer'}#
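attribute_map lets generic names such as hidden_size or num_hidden_layers resolve to GPT-2's native field names (n_embd, n_layer, and so on). The following is a minimal sketch of how such aliasing can be implemented by redirecting attribute reads and writes, modeled loosely on the Hugging Face PretrainedConfig approach. The class is hypothetical.

```python
from typing import Any, Dict


class AliasedConfig:
    """Hypothetical config that resolves aliased attribute names via attribute_map."""

    attribute_map: Dict[str, str] = {
        "hidden_size": "n_embd",
        "max_position_embeddings": "n_positions",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    def __init__(self, n_embd: int = 768, n_positions: int = 1024, n_head: int = 12, n_layer: int = 12):
        self.n_embd = n_embd
        self.n_positions = n_positions
        self.n_head = n_head
        self.n_layer = n_layer

    def __getattr__(self, name: str) -> Any:
        # Only called when normal lookup fails, i.e. for aliases or genuinely missing names.
        target = type(self).attribute_map.get(name)
        if target is None:
            raise AttributeError(name)
        return getattr(self, target)

    def __setattr__(self, name: str, value: Any) -> None:
        # Writes to an alias are redirected to the native field.
        super().__setattr__(type(self).attribute_map.get(name, name), value)


cfg = AliasedConfig()
print(cfg.hidden_size)  # 768
cfg.num_hidden_layers = 24
print(cfg.n_layer)  # 24
```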
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model. This method defines how the model’s parameters are partitioned across devices for distributed training and inference.
- Parameters
*args – Additional positional arguments (unused).
**kwargs – Additional keyword arguments (unused).
- Returns
A tuple of partition rules, where each rule is a tuple containing a regex pattern for parameter names and the corresponding PartitionSpec.
- Return type
tp.Tuple[tp.Tuple[str, jax.sharding.PartitionSpec]]
- keys_to_ignore_at_inference = ['past_key_values']#
- model_type: str = 'gpt2'#
- class easydel.__init__.GPT2LMHeadModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
GPT-2 model with a language modeling head.
This model extends the base GPT2Model by adding a linear layer on top to predict the next token in a sequence, making it suitable for causal language modeling tasks.
- config#
Configuration object for the model.
- Type
GPT2Config
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
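The language modeling head maps each final hidden state to one logit per vocabulary entry. When tie_word_embeddings=True, the projection reuses the input embedding matrix; otherwise a separate matrix of the same shape is learned. The following is a minimal pure-Python sketch of greedy next-token selection under the tied assumption, with a toy 3-token vocabulary. It ignores batching, dtype, and precision handling, and the function names are hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def tied_lm_logits(hidden: List[float], embedding: Matrix) -> List[float]:
    """logits[v] = <hidden, embedding[v]>, i.e. hidden @ embedding.T for one position."""
    return [sum(h * e for h, e in zip(hidden, row)) for row in embedding]


def greedy_next_token(hidden: List[float], embedding: Matrix) -> int:
    """Index of the highest-scoring vocabulary entry."""
    logits = tied_lm_logits(hidden, embedding)
    return max(range(len(logits)), key=logits.__getitem__)


# Toy embedding: vocab_size=3, hidden size=2.
embedding = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
last_hidden_state = [0.2, 0.9]
print(tied_lm_logits(last_hidden_state, embedding))  # approximately [0.2, 0.9, 0.55]
print(greedy_next_token(last_hidden_state, embedding))  # 1
```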
- class easydel.__init__.GPT2Model(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
GPT-2 model implementation.
This class implements the main GPT-2 transformer model architecture, consisting of embedding layers (token and position), multiple GPT2Block layers, and a final layer normalization.
- config#
Configuration object for the model.
- Type
GPT2Config
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- class easydel.__init__.GPTJConfig(vocab_size: int = 50400, n_positions: int = 2048, n_embd: int = 4096, n_layer: int = 28, n_head: int = 16, rotary_dim: int = 64, n_inner: int = None, activation_function: str = 'gelu_new', resid_pdrop: float = 0.0, embd_pdrop: float = 0.0, attn_pdrop: float = 0.0, layer_norm_epsilon: float = 1e-05, initializer_range: int = 0.02, use_cache: int = True, bos_token_id: int = 50256, eos_token_id: int = 50256, tie_word_embeddings: bool = False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 50400) – Vocabulary size of the GPT-J model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.
n_positions (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).
n_embd (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.
n_layer (int, optional, defaults to 28) – Number of hidden layers in the Transformer encoder.
n_head (int, optional, defaults to 16) – Number of attention heads for each attention layer in the Transformer encoder.
rotary_dim (int, optional, defaults to 64) – The dimension of the rotary position embedding.
n_inner (int, optional) – Dimensionality of the inner feed-forward layers.
activation_function (str, optional, defaults to “gelu_new”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.
resid_pdrop (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
embd_pdrop (float, optional, defaults to 0.0) – The dropout ratio for the embeddings.
attn_pdrop (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
layer_norm_epsilon (float, optional, defaults to 1e-5) – The epsilon to use in the layer normalization layers.
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
bos_token_id (int, optional, defaults to 50256) – The id of the beginning-of-sequence token.
eos_token_id (int, optional, defaults to 50256) – The id of the end-of-sequence token.
tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.
bits (int, optional) – The number of bits to quantize the model to.
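With the defaults above, each attention head has n_embd / n_head = 4096 / 16 = 256 dimensions, and rotary_dim=64 means only the first 64 of them are rotated while the remaining 192 pass through unchanged. The following is a minimal pure-Python sketch of that partial rotation for one head vector at one position. It assumes GPT-J-style interleaved pairing of adjacent dimensions (2i, 2i + 1) and the conventional base of 10000. It is illustrative, not EasyDeL's implementation, and the function name is hypothetical.

```python
import math
from typing import List


def apply_partial_rotary(x: List[float], position: int, rotary_dim: int, base: float = 10000.0) -> List[float]:
    """Rotate the first rotary_dim entries of one head vector; pass the rest through unchanged."""
    if rotary_dim % 2 != 0 or rotary_dim > len(x):
        raise ValueError("rotary_dim must be even and no larger than the head dimension")
    out = list(x)
    for i in range(rotary_dim // 2):
        theta = position * base ** (-(2 * i) / rotary_dim)  # angle for pair i
        cos, sin = math.cos(theta), math.sin(theta)
        x1, x2 = x[2 * i], x[2 * i + 1]  # interleaved pair (2i, 2i + 1)
        out[2 * i] = x1 * cos - x2 * sin
        out[2 * i + 1] = x2 * cos + x1 * sin
    return out


n_embd, n_head, rotary_dim = 4096, 16, 64
head_dim = n_embd // n_head
print(head_dim)  # 256

vec = [1.0] * head_dim
rotated = apply_partial_rotary(vec, position=5, rotary_dim=rotary_dim)
print(rotated[rotary_dim:] == vec[rotary_dim:])  # True: dims 64..255 are untouched
```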
- attach_custom_arguments(vocab_size: int = 50400, n_positions: int = 2048, n_embd: int = 4096, n_layer: int = 28, n_head: int = 16, rotary_dim: int = 64, n_inner: int = None, activation_function: str = 'gelu_new', resid_pdrop: float = 0.0, embd_pdrop: float = 0.0, attn_pdrop: float = 0.0, layer_norm_epsilon: float = 1e-05, initializer_range: int = 0.02, use_cache: int = True, bos_token_id: int = 50256, eos_token_id: int = 50256, tie_word_embeddings: bool = False, bits: Optional[int] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#
Attaches custom arguments to the configuration object.
This method allows adding or overriding configuration attributes dynamically. It iterates through a dictionary of basic configuration parameters and sets them as attributes of the configuration object if they don’t already exist.
- Parameters
vocab_size (int, optional) – Vocabulary size. Defaults to 50400.
n_positions (int, optional) – Maximum sequence length. Defaults to 2048.
n_embd (int, optional) – Hidden size. Defaults to 4096.
n_layer (int, optional) – Number of hidden layers. Defaults to 28.
n_head (int, optional) – Number of attention heads. Defaults to 16.
rotary_dim (int, optional) – Dimension for rotary position embeddings. Defaults to 64.
n_inner (int, optional) – Inner dimension of FFN. Defaults to None.
activation_function (str, optional) – Activation function. Defaults to “gelu_new”.
resid_pdrop (float, optional) – Residual dropout probability. Defaults to 0.0.
embd_pdrop (float, optional) – Embedding dropout probability. Defaults to 0.0.
attn_pdrop (float, optional) – Attention dropout probability. Defaults to 0.0.
layer_norm_epsilon (float, optional) – Epsilon for layer normalization. Defaults to 1e-5.
initializer_range (float, optional) – Standard deviation of the weight initializer. Defaults to 0.02.
use_cache (bool, optional) – Whether to use the KV cache. Defaults to True.
bos_token_id (int, optional) – Beginning-of-sequence token ID. Defaults to 50256.
eos_token_id (int, optional) – End-of-sequence token ID. Defaults to 50256.
tie_word_embeddings (bool, optional) – Whether to tie input/output embeddings. Defaults to False.
bits (tp.Optional[int], optional) – Quantization bits. Defaults to None.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing strategy. Defaults to EasyDeLGradientCheckPointers.NONE.
**kwargs – Additional keyword arguments.
- Returns
The configuration object itself (self).
- Return type
- attribute_map: dict[str, str] = {'hidden_size': 'n_embd', 'max_position_embeddings': 'n_positions', 'num_attention_heads': 'n_head', 'num_hidden_layers': 'n_layer'}#
- static get_mesh_names()[source]#
Returns the mesh names used for model parallelism. For GPT-J, it returns mesh names for data parallelism (‘dp’), fully sharded data parallelism (‘fsdp’), sequence parallelism (‘sp’), and tensor parallelism (‘tp’).
- Returns
A tuple containing the mesh names (“dp”, “fsdp”, “sp”, “tp”).
- Return type
tuple
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model. This method defines how the model’s parameters are partitioned across devices for distributed training and inference.
- Parameters
*args – Additional positional arguments (unused).
**kwargs – Additional keyword arguments (unused).
- Returns
A tuple of partition rules, where each rule is a tuple containing a regex pattern for parameter names and the corresponding PartitionSpec.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- model_type: str = 'gptj'#
- class easydel.__init__.GPTJForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
GPT-J model with a language modeling head.
This model extends the base GPTJModel by adding a linear layer on top to predict the next token in a sequence, making it suitable for causal language modeling tasks.
- config#
Configuration object for the model.
- Type
GPTJConfig
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- class easydel.__init__.GPTJModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
GPT-J model implementation.
This class implements the main GPT-J transformer model architecture, consisting of an embedding layer, multiple GPTJBlock layers, and a final layer normalization.
- config#
Configuration object for the model.
- Type
GPTJConfig
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- property frequencies#
Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.
Uses self.config.get_basic_frequencies() and caches the result.
- Returns
The frequency components, potentially cached.
- Return type
jnp.ndarray
- class easydel.__init__.GPTNeoXConfig(vocab_size=50432, hidden_size=6144, num_hidden_layers=44, num_attention_heads=64, intermediate_size=24576, hidden_act='gelu', rotary_pct=0.25, rotary_emb_base=10000, attention_dropout=0.0, hidden_dropout=0.0, classifier_dropout=0.1, max_position_embeddings=2048, initializer_range=0.02, layer_norm_eps=1e-05, use_cache=True, bos_token_id=0, eos_token_id=2, tie_word_embeddings=False, use_parallel_residual=True, rope_scaling=None, attention_bias=True, gradient_checkpointing=EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 50432) – Vocabulary size of the GPT NeoX model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.
hidden_size (int, optional, defaults to 6144) – Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (int, optional, defaults to 44) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 64) – Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (int, optional, defaults to 24576) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (str or function, optional, defaults to “gelu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.
rotary_pct (float, optional, defaults to 0.25) – The fraction (not a percentage) of each attention head's dimensions that receive rotary embeddings; 0.25 means one quarter.
rotary_emb_base (int, optional, defaults to 10000) – The base for the rotary position embedding.
classifier_dropout (float, optional, defaults to 0.1) – The dropout ratio for the classifier layer.
max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
bos_token_id (int, optional, defaults to 0) – The id of the beginning-of-sequence token.
eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.
tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.
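gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.
use_parallel_residual (bool, optional, defaults to True) – Whether to compute attention and the MLP in parallel from the same layer input and add both outputs to the residual stream, instead of applying them sequentially.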
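Two of these options are easiest to see in code. With use_parallel_residual=True, attention and the MLP both read the same layer input and their outputs are added to the residual together; with False, the MLP reads the output of the attention residual. For rotary_pct, with the defaults head_size = 6144 / 64 = 96 and int(96 * 0.25) = 24 dimensions per head receive rotary embeddings. The following sketch follows the Hugging Face GPT-NeoX formulation as an assumption, uses scalar toy functions in place of the real attention, MLP, and layer norms, and its function names are hypothetical.

```python
from typing import Callable

Fn = Callable[[float], float]


def neox_layer(x: float, attn: Fn, mlp: Fn, ln1: Fn, ln2: Fn, use_parallel_residual: bool) -> float:
    """One GPT-NeoX-style residual block on a scalar stand-in for the hidden state."""
    if use_parallel_residual:
        # x + attn(ln1(x)) + mlp(ln2(x)): both branches see the same input.
        return x + attn(ln1(x)) + mlp(ln2(x))
    # Sequential: the MLP sees the output of the attention residual.
    h = x + attn(ln1(x))
    return h + mlp(ln2(h))


def rotary_ndims(hidden_size: int, num_attention_heads: int, rotary_pct: float) -> int:
    """Number of per-head dimensions that receive rotary embeddings."""
    head_size = hidden_size // num_attention_heads
    return int(head_size * rotary_pct)


def identity(v: float) -> float:
    return v


def toy_attn(v: float) -> float:
    return 2.0 * v


def toy_mlp(v: float) -> float:
    return v + 1.0


print(neox_layer(1.0, toy_attn, toy_mlp, identity, identity, use_parallel_residual=True))  # 5.0
print(neox_layer(1.0, toy_attn, toy_mlp, identity, identity, use_parallel_residual=False))  # 7.0
print(rotary_ndims(6144, 64, 0.25))  # 24
```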
- attach_custom_arguments(**kwargs)[source]#
Attaches custom arguments to the configuration object.
This method allows adding or overriding configuration attributes dynamically. It primarily sets the from_pt attribute to False and ignores other keyword arguments.
- Parameters
**kwargs – Additional keyword arguments (ignored).
- static get_mesh_names()[source]#
Returns the mesh names used for model parallelism. For GPT-NeoX, it returns mesh names for data parallelism (‘dp’), fully sharded data parallelism (‘fsdp’), tensor parallelism (‘tp’), and sequence parallelism (‘sp’).
- Returns
A tuple containing the mesh names (“dp”, “fsdp”, “tp”, “sp”).
- Return type
tuple
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model. This method defines how the model’s parameters are partitioned across devices for distributed training and inference.
- Parameters
*args – Additional positional arguments (unused).
**kwargs – Additional keyword arguments (unused).
- Returns
A tuple of partition rules, where each rule is a tuple containing a regex pattern for parameter names and the corresponding PartitionSpec.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- model_type: str = 'gpt_neox'#
- class easydel.__init__.GPTNeoXForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
GPT-NeoX model with a language modeling head.
This model extends the base GPTNeoXModel by adding a linear layer on top to predict the next token in a sequence, making it suitable for causal language modeling tasks.
- config#
Configuration object for the model.
- Type
GPTNeoXConfig
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- class easydel.__init__.GPTNeoXModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
GPT-NeoX model implementation.
This class implements the main GPT-NeoX transformer model architecture, consisting of an embedding layer, multiple GPTNeoXBlock layers, and a final layer normalization.
- config#
Configuration object for the model.
- Type
GPTNeoXConfig
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- property frequencies#
Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.
Uses self.config.get_basic_frequencies() and caches the result.
- Returns
The frequency components, potentially cached.
- Return type
jnp.ndarray
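The compute-once-then-cache behaviour of frequencies can be illustrated in plain Python. This is a sketch under assumptions:
- ToyConfig and its get_basic_frequencies() are hypothetical stand-ins; the real method's return layout is not documented here.
- The standard RoPE inverse frequencies, theta ** (-2i / head_dim), stand in for the computed value.

```python
import functools
from dataclasses import dataclass
from typing import List


@dataclass
class ToyConfig:
    # Hypothetical stand-in for a model config; not EasyDeL's class.
    head_dim: int = 8
    rope_theta: float = 10000.0
    calls: int = 0  # counts how often the frequencies are computed

    def get_basic_frequencies(self) -> List[float]:
        # Standard RoPE inverse frequencies: theta ** (-2i / head_dim), i = 0 .. head_dim/2 - 1.
        self.calls += 1
        return [self.rope_theta ** (-2 * i / self.head_dim) for i in range(self.head_dim // 2)]


class ToyModel:
    def __init__(self, config: ToyConfig):
        self.config = config

    @functools.cached_property
    def frequencies(self) -> List[float]:
        # Computed on first access, then served from the instance cache.
        return self.config.get_basic_frequencies()


model = ToyModel(ToyConfig())
first = model.frequencies
second = model.frequencies
print(first)               # approximately [1.0, 0.1, 0.01, 0.001]
print(model.config.calls)  # 1
```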
- class easydel.__init__.GRPOConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: tp.Optional[str] = None, clip_grad: tp.Optional[float] = None, custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = None, dataloader_num_workers: tp.Optional[int] = 0, dataloader_pin_memory: tp.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: tp.Optional[int] = None, evaluation_steps: tp.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: tp.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: tp.Optional[dict] = None, learning_rate: float = 1e-06, learning_rate_end: tp.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: tp.Optional[LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: tp.Optional[int] = None, max_sequence_length: tp.Optional[int] = 4096, max_training_steps: tp.Optional[int] = None, model_name: str = 'GRPOTrainer', model_parameters: tp.Optional[dict] = None, metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: ~typing.Optional[bool] = False, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: tp.Optional[int] = None, save_total_limit: tp.Optional[int] = None, scheduler: AVAILABLE_SCHEDULERS = EasyDeLSchedulers.NONE, 
sparsify_module: bool = False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: tp.Optional[dict] = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: tp.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: tp.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: tp.Optional[jnp.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: tp.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0, max_prompt_length: int = 512, max_completion_length: int = 256, dataset_num_proc: ~typing.Optional[int] = None, beta: float = 0.04, sync_ref_model: bool = False, ref_model_mixup_alpha: float = 0.9, ref_model_sync_steps: int = 64, tools: ~typing.Optional[~typing.List[~typing.Union[dict, ~typing.Callable]]] = None, skip_apply_chat_template: bool = False)[source]#
Bases:
TrainingArguments
Configuration class for the GRPOTrainer.
- beta: float = 0.04#
- dataset_num_proc: Optional[int] = None#
- classmethod from_dict(data)#
Create an instance from a dictionary (deserialization).
- classmethod from_json(json_str)#
Create an instance from a JSON string.
- learning_rate: float = 1e-06#
- max_completion_length: int = 256#
- max_prompt_length: int = 512#
- model_name: str = 'GRPOTrainer'#
- ref_model_mixup_alpha: float = 0.9#
- ref_model_sync_steps: int = 64#
- remove_unused_columns: Optional[bool] = False#
- replace(**kwargs)#
- skip_apply_chat_template: bool = False#
- sync_ref_model: bool = False#
- to_dict()#
Convert the instance to a dictionary for JSON serialization.
- to_json(**kwargs)#
Convert the instance to a JSON string.
- tools: Optional[List[Union[dict, Callable]]] = None#
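GRPOConfig supports dictionary and JSON round-tripping (to_dict, from_dict, to_json, from_json) and provides a replace(**kwargs) method. The sketch below mimics that pattern with a hypothetical, trimmed-down dataclass, MiniGRPOConfig, which is not the EasyDeL class:
- It keeps only a few of the GRPO-specific fields, with the defaults listed above.
- It assumes that replace returns a modified copy, as dataclasses.replace does.

```python
import dataclasses
import json
from typing import Any, Dict


@dataclasses.dataclass(frozen=True)
class MiniGRPOConfig:
    # Hypothetical subset of GRPOConfig fields, using the defaults listed above.
    model_name: str = "GRPOTrainer"
    learning_rate: float = 1e-6
    beta: float = 0.04
    max_prompt_length: int = 512
    max_completion_length: int = 256
    sync_ref_model: bool = False
    ref_model_mixup_alpha: float = 0.9
    ref_model_sync_steps: int = 64

    def to_dict(self) -> Dict[str, Any]:
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MiniGRPOConfig":
        # Keep only keys that are fields of this class; unknown keys are dropped.
        names = {f.name for f in dataclasses.fields(cls)}
        return cls(**{k: v for k, v in data.items() if k in names})

    def to_json(self, **kwargs: Any) -> str:
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str: str) -> "MiniGRPOConfig":
        return cls.from_dict(json.loads(json_str))

    def replace(self, **kwargs: Any) -> "MiniGRPOConfig":
        # Return a copy with selected fields overridden.
        return dataclasses.replace(self, **kwargs)


cfg = MiniGRPOConfig().replace(beta=0.1, sync_ref_model=True)
restored = MiniGRPOConfig.from_json(cfg.to_json(indent=2))
print(restored == cfg)  # True
```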
- class easydel.__init__.GRPOTrainer(arguments: GRPOConfig, vinference: vInference, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]], reward_funcs: Union[EasyDeLBaseModule, EasyDeLState, Callable[[list, list], list[float]], list[Union[easydel.infra.base_module.EasyDeLBaseModule, easydel.infra.base_state.EasyDeLState, Callable[[list, list], list[float]]]]], train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, processing_class: Optional[Any] = None, reward_processing_classes: Optional[Any] = None, data_tokenize_fn: Optional[Callable] = None)[source]#
Bases:
Trainer
- arguments: GRPOConfig#
- configure_functions() TrainerConfigureFunctionOutput[source]#
Configures and JIT-compiles the training and evaluation step functions.
This method sets up the functions needed for training and evaluation, including initialization of the model state, sharding of the model parameters and optimizer state, and JIT compilation of the training and evaluation step functions.
- Returns
An object containing the configured functions and other relevant information.
- Return type
TrainerConfigureFunctionOutput
- on_step_end(state: EasyDeLState, metrics: Any, step: int) Tuple[EasyDeLState, Any][source]#
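Hook called at the end of each training step. It receives the current state, the step metrics and the step number, and returns a (possibly updated) state and metrics.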
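The reward_funcs argument of GRPOTrainer accepts plain callables typed Callable[[list, list], list[float]], either alone or in a list. Below is a hedged sketch of such a function:
- The signature does not spell out what the two lists contain. We assume they are the prompts and the sampled completions, with one reward returned per completion.
- length_penalty_reward and its scoring rule are hypothetical and purely illustrative.

```python
from typing import List


def length_penalty_reward(prompts: list, completions: list) -> List[float]:
    """Hypothetical reward: prefer completions close to a target length in words.

    Assumes prompts[i] produced completions[i]; the prompt is unused here but is
    kept to match the Callable[[list, list], list[float]] shape.
    """
    target_words = 20
    rewards: List[float] = []
    for completion in completions:
        n_words = len(str(completion).split())
        # 1.0 at the target length, decreasing linearly, floored at 0.0.
        rewards.append(max(0.0, 1.0 - abs(n_words - target_words) / target_words))
    return rewards


# One float per completion: 20 words scores 1.0, 10 words scores 0.5.
print(length_penalty_reward(["Q1", "Q2"], ["word " * 20, "a b c d e f g h i j"]))  # [1.0, 0.5]
```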
- class easydel.__init__.Gemma2Config(vocab_size=256000, hidden_size=3072, intermediate_size=24576, num_hidden_layers=28, num_attention_heads=16, num_key_value_heads=16, head_dim=256, hidden_activation='gelu_pytorch_tanh', max_position_embeddings=8192, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, pad_token_id=0, eos_token_id=1, bos_token_id=2, tie_word_embeddings=True, rope_theta=10000.0, attention_bias=False, attention_dropout=0.0, final_logit_softcapping=30.0, query_pre_attn_scalar=224, sliding_window=4096, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, scan_layers: bool = False, attn_logit_softcapping: Optional[bool] = None, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 256000) – Vocabulary size of the Gemma2 model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.
hidden_size (int, optional, defaults to 3072) – Dimensionality of the encoder layers and the pooler layer.
intermediate_size (int, optional, defaults to 24576) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (int, optional, defaults to 28) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 16) – Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (int, optional, defaults to 16) – Number of key and value heads for each attention layer in the Transformer encoder.
head_dim (int, optional, defaults to 256) – Dimensionality of the attention head.
hidden_activation (str or function, optional, defaults to “gelu_pytorch_tanh”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.
max_position_embeddings (int, optional, defaults to 8192) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
pad_token_id (int, optional, defaults to 0) – The index of the padding token in the vocabulary.
eos_token_id (int, optional, defaults to 1) – The index of the end of sequence token in the vocabulary.
bos_token_id (int, optional, defaults to 2) – The index of the beginning of sequence token in the vocabulary.
tie_word_embeddings (bool, optional, defaults to True) – Whether to tie the weights of the input embeddings and the output embeddings.
rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.
attention_bias (bool, optional, defaults to False) – Whether to use attention bias.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
final_logit_softcapping (float, optional, defaults to 30.0) – The soft capping value for the final logits.
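query_pre_attn_scalar (int, optional, defaults to 224) – Scaling factor for the attention scores: queries are scaled by query_pre_attn_scalar ** -0.5 (instead of head_dim ** -0.5) before the query-key dot product.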
sliding_window (int, optional, defaults to 4096) – The sliding window size.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing strategy.
bits (int, optional) – The number of bits to quantize the model to.
scan_layers (bool, optional, defaults to False) – Whether to use the scan implementation of the layers.
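Two of the Gemma 2 options above are easy to misread, so the sketch below illustrates both:
- Tanh soft-capping of logits, cap * tanh(x / cap), which uses final_logit_softcapping=30.0.
- Query scaling by query_pre_attn_scalar ** -0.5.

This is how the Hugging Face Gemma 2 reference uses these values. We assume, without having verified it here, that EasyDeL follows the same convention.

```python
import math
from typing import List, Optional


def soft_cap(logits: List[float], cap: Optional[float]) -> List[float]:
    """Smoothly squash values toward (-cap, cap); cap=None disables capping."""
    if cap is None:
        return list(logits)
    return [cap * math.tanh(x / cap) for x in logits]


def query_scale(query_pre_attn_scalar: float) -> float:
    """Factor multiplied into the queries before the query-key dot product."""
    return query_pre_attn_scalar ** -0.5


capped = soft_cap([0.0, 10.0, 100.0], cap=30.0)
# Small logits change little; large ones saturate just below 30.
print([round(v, 3) for v in capped])
print(query_scale(224))  # ~0.0668, versus 1/sqrt(head_dim=256) = 0.0625
```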
- attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#
Attaches custom EasyDeL arguments (gradient checkpointing and quantization settings) to the configuration object.
- Parameters
gradient_checkpointing (EasyDeLGradientCheckPointers) – Gradient checkpointing strategy, which trades recomputation for lower memory use in JAX.
bits (tp.Optional[int]) – Number of bits used for quantization.
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- static get_weight_decay_exclusions()[source]#
Returns a tuple of parameter names for which weight decay should be excluded.
- Returns
An empty tuple, indicating no weight decay exclusions.
- Return type
tuple
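An empty exclusion tuple means every parameter receives weight decay. The sketch below shows one way a trainer could turn such a tuple into a per-parameter decay mask. Two things are assumptions:
- Exclusions are matched as substrings of the parameter path.
- The parameter paths are hypothetical.

```python
from typing import Dict, Iterable, Tuple


def weight_decay_mask(param_paths: Iterable[str], exclusions: Tuple[str, ...]) -> Dict[str, bool]:
    """Map each parameter path to True (apply decay) or False (skip decay)."""
    return {path: not any(ex in path for ex in exclusions) for path in param_paths}


paths = ["model/layers/0/mlp/gate_proj/kernel", "model/layers/0/input_layernorm/kernel"]
print(weight_decay_mask(paths, ()))              # every value is True
print(weight_decay_mask(paths, ("layernorm",)))  # the norm weight is excluded
```

With Optax, a boolean mask of this kind, nested to mirror the parameter pytree, is what the mask argument of optax.adamw expects.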
- property granted_freq_max_position_embedding: int#
Returns the maximum position embedding size for frequency-based position embeddings.
- Returns
The maximum position embedding size, falling back to max_position_embeddings if not explicitly set.
- Return type
int
- property granted_mask_max_position_embedding: int#
Returns the maximum position embedding size for mask-based position embeddings.
- Returns
The maximum position embedding size, falling back to max_position_embeddings if not explicitly set.
- Return type
int
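The fallback described for both granted_* properties amounts to a simple rule: use the explicit value if it is set, otherwise use max_position_embeddings. A minimal sketch follows, with these assumptions:
- The override attributes are named freq_max_position_embeddings and mask_max_position_embeddings (the second name is hypothetical).
- Both overrides default to None.
- PositionLimits is a hypothetical stand-in for the config.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class PositionLimits:
    # Hypothetical stand-in for the relevant config fields.
    max_position_embeddings: int = 8192
    freq_max_position_embeddings: Optional[int] = None
    mask_max_position_embeddings: Optional[int] = None

    @property
    def granted_freq_max_position_embedding(self) -> int:
        # Use the explicit override when present, else the model-wide limit.
        if self.freq_max_position_embeddings is not None:
            return self.freq_max_position_embeddings
        return self.max_position_embeddings

    @property
    def granted_mask_max_position_embedding(self) -> int:
        if self.mask_max_position_embeddings is not None:
            return self.mask_max_position_embeddings
        return self.max_position_embeddings


limits = PositionLimits(freq_max_position_embeddings=32768)
print(limits.granted_freq_max_position_embedding)  # 32768
print(limits.granted_mask_max_position_embedding)  # 8192
```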
- model_type: str = 'gemma2'#
- class easydel.__init__.Gemma2ForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Gemma2 model with a language modeling head for causal language modeling tasks.
This model extends the base Gemma2Model by incorporating a linear language modeling head on top of the base model, designed for generative tasks and text generation.
- class easydel.__init__.Gemma2ForSequenceClassification(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.Gemma2Model(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.Gemma3Config(text_config: Optional[Gemma3TextConfig] = None, vision_config: Optional[SiglipVisionConfig] = None, mm_tokens_per_image: int = 256, boi_token_index: int = 255999, eoi_token_index: int = 256000, image_token_index: int = 262144, initializer_range: float = 0.02, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
- Parameters
text_config (Union[Gemma3TextConfig, dict], optional) – The config object of the text backbone.
vision_config (Union[AutoConfig, dict], optional) – Custom vision config or dict.
mm_tokens_per_image (int, optional, defaults to 256) – The number of tokens per image embedding.
boi_token_index (int, optional, defaults to 255999) – The begin-of-image token index to wrap the image prompt.
eoi_token_index (int, optional, defaults to 256000) – The end-of-image token index to wrap the image prompt.
image_token_index (int, optional, defaults to 262144) – The image token index to encode the image prompt.
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
Example:
```python
>>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig
>>> # Initializing a Siglip-like vision config
>>> vision_config = SiglipVisionConfig()
>>> # Initializing a Gemma3 Text config
>>> text_config = Gemma3TextConfig()
>>> # Initializing a Gemma3 gemma-3-4b style configuration (text_config is the first parameter, so pass both by keyword)
>>> configuration = Gemma3Config(text_config=text_config, vision_config=vision_config)
>>> # Initializing a model from the gemma-3-4b style configuration
>>> model = Gemma3ForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model.
- Returns
A tuple of tuples, where each inner tuple contains a regex pattern matching parameter names and the corresponding PartitionSpec for sharding those parameters across devices.
- Return type
Tuple[Tuple[str, PartitionSpec]]
- model_type: str = 'gemma3'#
- sub_configs: dict[str, 'PretrainedConfig'] = {'text_config': <class 'easydel.__init__.modules.gemma3.gemma3_configuration.Gemma3TextConfig'>, 'vision_config': <class 'easydel.__init__.modules.siglip.configuration_siglip.SiglipVisionConfig'>}#
- class easydel.__init__.Gemma3ForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Gemma3 model with a language modeling head for causal language modeling tasks.
This model extends the base Gemma3TextModel by incorporating a linear language modeling head on top of the base model, designed for generative tasks and text generation. The model can optionally apply softcapping to logits based on configuration settings.
- class easydel.__init__.Gemma3ForConditionalGeneration(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- get_image_features(pixel_values: Union[Array, ndarray, bool, number]) Union[Array, ndarray, bool, number][source]#
- loss_type = 'ForCausalLM'#
- prepare_inputs_for_generation(input_ids: Union[Array, ndarray, bool, number], max_length: int, pixel_values: Optional[Union[Array, ndarray, bool, number]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None, token_type_ids: Optional[Union[Array, ndarray, bool, number]] = None)[source]#
Prepares the model inputs for a generation task.
- Parameters
input_ids – The input token ids.
max_length – The maximum length of the sequence to generate.
pixel_values (tp.Optional[chex.Array]) – Optional image pixel values for multimodal inputs.
attention_mask (tp.Optional[chex.Array]) – Mask indicating which tokens can be attended to.
token_type_ids (tp.Optional[chex.Array]) – Token type ids.
- Returns
A dictionary of the past_key_values, attention_mask and position ids
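Many causal-LM generation utilities handle this step the same way. They pad the attention mask out to max_length, then derive position ids from the cumulative sum of the mask, so a left-padded prompt still starts at position 0. The sketch below reproduces a variant of that pattern in plain Python for a single sequence, with these caveats:
- It is an assumption about what this method does internally, not its actual code.
- prepare_generation_inputs is a hypothetical name.
- It omits the KV cache and pixel_values handling.

```python
from typing import Dict, List


def prepare_generation_inputs(attention_mask: List[int], max_length: int) -> Dict[str, List[int]]:
    """Pad a 0/1 prompt mask to max_length and derive position ids (one sequence)."""
    prompt_len = len(attention_mask)
    if prompt_len > max_length:
        raise ValueError("prompt is longer than max_length")
    # Slots reserved for generated tokens are attendable, so they get 1s.
    extended_mask = list(attention_mask) + [1] * (max_length - prompt_len)
    # Position ids count only real (mask == 1) tokens; padding is clamped to 0.
    position_ids: List[int] = []
    running = 0
    for m in attention_mask:
        running += m
        position_ids.append(max(running - 1, 0))
    return {"attention_mask": extended_mask, "position_ids": position_ids}


# Left-padded prompt of length 5 (two pad tokens), generating up to 8 tokens in total.
print(prepare_generation_inputs([0, 0, 1, 1, 1], max_length=8))
# {'attention_mask': [0, 0, 1, 1, 1, 1, 1, 1], 'position_ids': [0, 0, 0, 1, 2]}
```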
- class easydel.__init__.Gemma3ForSequenceClassification(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.Gemma3TextConfig(vocab_size=262208, hidden_size=2304, intermediate_size=9216, num_hidden_layers=26, num_attention_heads=8, num_key_value_heads=4, head_dim=256, hidden_activation='gelu_pytorch_tanh', max_position_embeddings=131072, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, pad_token_id=0, eos_token_id=1, bos_token_id=2, tie_word_embeddings=True, rope_theta=1000000.0, attention_bias=False, attention_dropout=0.0, query_pre_attn_scalar=256, sliding_window=4096, final_logit_softcapping=None, attn_logit_softcapping=None, cache_implementation='hybrid', rope_scaling=None, rope_local_base_freq=10000.0, sliding_window_pattern=6, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, scan_layers: bool = False, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 262208) – Vocabulary size of the Gemma3Text model. Defines the number of different tokens that can be represented by the inputs_ids passed when calling [Gemma3TextModel].
hidden_size (int, optional, defaults to 2304) – Dimension of the hidden representations.
intermediate_size (int, optional, defaults to 9216) – Dimension of the MLP representations.
num_hidden_layers (int, optional, defaults to 26) – Number of hidden layers in the Transformer decoder.
num_attention_heads (int, optional, defaults to 8) – Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (int, optional, defaults to 4) – The number of key/value heads used to implement Grouped Query Attention (GQA). If num_key_value_heads=num_attention_heads, the model uses Multi Head Attention (MHA); if num_key_value_heads=1, it uses Multi Query Attention (MQA); otherwise it uses GQA. When converting a multi-head checkpoint to a GQA checkpoint, each group's key and value head should be constructed by mean-pooling all the original heads within that group. For more details, check out [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it defaults to num_attention_heads.
head_dim (int, optional, defaults to 256) – The attention head dimension.
hidden_activation (str or function, optional, defaults to “gelu_pytorch_tanh”) – The non-linear activation function (function or string) in the decoder. Will default to “gelu_pytorch_tanh” if not specified. “gelu_pytorch_tanh” uses an approximation of the “gelu” activation function.
max_position_embeddings (int, optional, defaults to 131072) – The maximum sequence length that this model might ever be used with.
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (float, optional, defaults to 1e-06) – The epsilon used by the rms normalization layers.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
pad_token_id (int, optional, defaults to 0) – Padding token id.
eos_token_id (int, optional, defaults to 1) – End of stream token id.
bos_token_id (int, optional, defaults to 2) – Beginning of stream token id.
tie_word_embeddings (bool, optional, defaults to True) – Whether to tie the input and output word embeddings.
rope_theta (float, optional, defaults to 1000000.0) – The base period of the RoPE embeddings.
attention_bias (bool, optional, defaults to False) – Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
query_pre_attn_scalar (float, optional, defaults to 256) – Scaling factor used on the attention scores
sliding_window (int, optional, defaults to 4096) – Size of the sliding window used by the local (sliding-window) attention layers in Gemma3Text. How these layers are interleaved with global attention layers is controlled by sliding_window_pattern.
final_logit_softcapping (float, optional) – Scaling factor when applying tanh softcapping on the logits.
attn_logit_softcapping (float, optional) – Scaling factor when applying tanh softcapping on the attention scores.
cache_implementation (str, optional, defaults to “hybrid”) – The cache type to be used with generate.
rope_scaling (Dict, optional) –
Dictionary containing the scaling configuration for the RoPE embeddings used in global attention. NOTE: if you apply a new rope type and expect the model to work on a longer max_position_embeddings, we recommend updating this value accordingly. Expected contents:
- rope_type (str):
The sub-variant of RoPE to use. Can be one of [‘default’, ‘linear’, ‘dynamic’, ‘yarn’, ‘longrope’, ‘llama3’], with ‘default’ being the original RoPE implementation.
- factor (float, optional):
Used with all rope types except ‘default’. The scaling factor to apply to the RoPE embeddings. In most scaling types, a factor of x will enable the model to handle sequences of length x * original maximum pre-trained length.
- original_max_position_embeddings (int, optional):
Used with ‘dynamic’, ‘longrope’ and ‘llama3’. The original max position embeddings used during pretraining.
- attention_factor (float, optional):
Used with ‘yarn’ and ‘longrope’. The scaling factor to be applied on the attention computation. If unspecified, it defaults to value recommended by the implementation, using the factor field to infer the suggested value.
- beta_fast (float, optional):
Only used with ‘yarn’. Parameter to set the boundary for extrapolation (only) in the linear ramp function. If unspecified, it defaults to 32.
- beta_slow (float, optional):
Only used with ‘yarn’. Parameter to set the boundary for interpolation (only) in the linear ramp function. If unspecified, it defaults to 1.
- short_factor (List[float], optional):
Only used with ‘longrope’. The scaling factor to be applied to short contexts (< original_max_position_embeddings). Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2
- long_factor (List[float], optional):
Only used with ‘longrope’. The scaling factor to be applied to long contexts (> original_max_position_embeddings). Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2.
- low_freq_factor (float, optional):
Only used with ‘llama3’. Scaling factor applied to low frequency components of the RoPE
- high_freq_factor (float, optional):
Only used with ‘llama3’. Scaling factor applied to high frequency components of the RoPE
rope_local_base_freq (float, optional, defaults to 10000.0) – The base period of the RoPE embeddings for local attention.
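sliding_window_pattern (int, optional, defaults to 6) – Pattern that determines how sliding-window (local) attention layers are interleaved with global attention layers.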
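The num_key_value_heads entry describes converting a multi-head checkpoint to GQA by mean-pooling the original key/value heads within each group. A minimal sketch of that conversion follows, with these assumptions:
- Each head's projection weights are represented as a flat list of floats.
- Consecutive heads form a group, which is the usual layout.
- mean_pool_kv_heads is a hypothetical helper name.

Real checkpoints would apply the same averaging over weight arrays.

```python
from typing import List

Head = List[float]  # flattened weights of one key (or value) head


def mean_pool_kv_heads(heads: List[Head], num_key_value_heads: int) -> List[Head]:
    """Average consecutive groups of MHA key/value heads into GQA heads."""
    num_heads = len(heads)
    if num_key_value_heads <= 0 or num_heads % num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    group_size = num_heads // num_key_value_heads
    pooled: List[Head] = []
    for g in range(num_key_value_heads):
        group = heads[g * group_size:(g + 1) * group_size]
        # Element-wise mean over the heads in this group.
        pooled.append([sum(vals) / group_size for vals in zip(*group)])
    return pooled


# 8 MHA heads -> 4 GQA heads (num_key_value_heads=4), matching the Gemma3Text defaults.
mha = [[float(h), float(h) * 10] for h in range(8)]
print(mean_pool_kv_heads(mha, 4))
# [[0.5, 5.0], [2.5, 25.0], [4.5, 45.0], [6.5, 65.0]]
```

Setting num_key_value_heads=1 collapses everything into a single shared head (MQA). Setting it equal to the head count leaves the heads unchanged (MHA).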
- attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#
Attaches custom EasyDeL arguments (gradient checkpointing and quantization settings) to the configuration object.
- Parameters
gradient_checkpointing (EasyDeLGradientCheckPointers) – Gradient checkpointing strategy, which trades recomputation for lower memory use in JAX.
bits (tp.Optional[int]) – Number of bits used for quantization.
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- property granted_freq_max_position_embedding: int#
- property granted_mask_max_position_embedding: int#
- model_type: str = 'gemma3_text'#
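The sketch below shows how sliding_window_pattern, sliding_window, rope_theta and rope_local_base_freq could combine into a per-layer plan, using the defaults from the Gemma3TextConfig signature. It is hedged in three ways:
- The rule that layer i is global when (i + 1) % sliding_window_pattern == 0 is taken from the Hugging Face Gemma 3 reference. We assume, without having verified it here, that EasyDeL applies the same rule.
- The assumption that local layers use rope_local_base_freq and global layers use rope_theta comes from the same reference.
- LayerPlan and layer_plan are hypothetical names.

```python
from typing import List, NamedTuple, Optional


class LayerPlan(NamedTuple):
    index: int
    is_sliding: bool
    rope_base: float
    window: Optional[int]  # None means full (global) attention


def layer_plan(num_hidden_layers: int = 26,
               sliding_window_pattern: int = 6,
               sliding_window: int = 4096,
               rope_theta: float = 1_000_000.0,
               rope_local_base_freq: float = 10_000.0) -> List[LayerPlan]:
    """Assumed rule: every sliding_window_pattern-th layer is global, the rest are local."""
    plan: List[LayerPlan] = []
    for i in range(num_hidden_layers):
        is_sliding = (i + 1) % sliding_window_pattern != 0
        plan.append(LayerPlan(
            index=i,
            is_sliding=is_sliding,
            rope_base=rope_local_base_freq if is_sliding else rope_theta,
            window=sliding_window if is_sliding else None,
        ))
    return plan


plan = layer_plan()
print([p.index for p in plan if not p.is_sliding])  # [5, 11, 17, 23]
```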
- class easydel.__init__.Gemma3TextModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- property default_frequencies#
- class easydel.__init__.GemmaConfig(vocab_size=256000, hidden_size=3072, intermediate_size=24576, num_hidden_layers=28, num_attention_heads=16, num_key_value_heads=16, head_dim=256, hidden_act='gelu_pytorch_tanh', max_position_embeddings=8192, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, pad_token_id=0, eos_token_id=1, bos_token_id=2, tie_word_embeddings=True, rope_theta=10000.0, attention_bias=False, attention_dropout=0.0, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, scan_layers: bool = False, hidden_activation='gelu_pytorch_tanh', **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 256000) – Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.
hidden_size (int, optional, defaults to 3072) – Dimensionality of the encoder layers and the pooler layer.
intermediate_size (int, optional, defaults to 24576) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (int, optional, defaults to 28) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 16) – Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (int, optional, defaults to 16) – Number of key and value heads for each attention layer in the Transformer encoder.
head_dim (int, optional, defaults to 256) – Dimensionality of the attention head.
hidden_act (str or function, optional, defaults to “gelu_pytorch_tanh”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.
max_position_embeddings (int, optional, defaults to 8192) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
pad_token_id (int, optional, defaults to 0) – The index of the padding token in the vocabulary.
eos_token_id (int, optional, defaults to 1) – The index of the end of sequence token in the vocabulary.
bos_token_id (int, optional, defaults to 2) – The index of the beginning of sequence token in the vocabulary.
tie_word_embeddings (bool, optional, defaults to True) – Whether to tie the weights of the input embeddings and the output embeddings.
rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.
attention_bias (bool, optional, defaults to False) – Whether to use attention bias.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing strategy.
bits (int, optional) – The number of bits to quantize the model to.
scan_layers (bool, optional, defaults to False) – Whether to use the scan implementation of the layers.
hidden_activation (str, optional) – The hidden activation function to use.
- attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#
Attaches custom EasyDeL arguments (gradient checkpointing and quantization settings) to the configuration object.
- Parameters
gradient_checkpointing (EasyDeLGradientCheckPointers) – Gradient checkpointing strategy, which trades recomputation for lower memory use in JAX.
bits (tp.Optional[int]) – Number of bits used for quantization.
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- static get_weight_decay_exclusions()[source]#
Returns a tuple of parameter names for which weight decay should be excluded.
- Returns
An empty tuple, indicating no weight decay exclusions.
- Return type
tuple
- property granted_freq_max_position_embedding: int#
Returns the maximum position embedding size for frequency-based position embeddings.
- Returns
The maximum position embedding size, falling back to max_position_embeddings if not explicitly set.
- Return type
int
- property granted_mask_max_position_embedding: int#
Returns the maximum position embedding size for mask-based position embeddings.
- Returns
The maximum position embedding size, falling back to max_position_embeddings if not explicitly set.
- Return type
int
- model_type: str = 'gemma'#
- class easydel.__init__.GemmaForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.GemmaForSequenceClassification(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.GemmaModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.Grok1Config(vocab_size=32000, hidden_size=4096, intermediate_size=32768, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=32, attn_output_multiplier=1.0, max_attn_value=1.0, max_position_embeddings=4096, embedding_multiplier_scale: float = 1.0, output_multiplier_scale: float = 1.0, rms_norm_eps=1e-05, use_cache=True, pad_token_id=None, bos_token_id=1, eos_token_id=2, tie_word_embeddings=True, num_experts_per_tok=2, num_experts=8, output_router_logits=False, router_aux_loss_coef=0.001, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 32000) – Vocabulary size of the Grok-1 model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.
hidden_size (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.
intermediate_size (int, optional, defaults to 32768) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (int, optional, defaults to 32) – Number of key and value heads for each attention layer in the Transformer encoder.
attn_output_multiplier (float, optional, defaults to 1.0) – The multiplier value applied to the attention output.
max_attn_value (float, optional, defaults to 1.0) – The maximum value of the attention weights.
max_position_embeddings (int, optional, defaults to 4096) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).
embedding_multiplier_scale (float, optional, defaults to 1.0) – The scale factor for the embedding layer.
output_multiplier_scale (float, optional, defaults to 1.0) – The scale factor for the output layer.
rms_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the rms normalization layers.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
pad_token_id (int, optional) – The index of the padding token in the vocabulary.
bos_token_id (int, optional, defaults to 1) – The index of the beginning of sequence token in the vocabulary.
eos_token_id (int, optional, defaults to 2) – The index of the end of sequence token in the vocabulary.
tie_word_embeddings (bool, optional, defaults to True) – Whether to tie the weights of the input embeddings and the output embeddings.
num_experts_per_tok (int, optional, defaults to 2) – The number of experts per token.
num_experts (int, optional, defaults to 8) – The number of experts.
output_router_logits (bool, optional, defaults to False) – Whether to output router logits.
router_aux_loss_coef (float, optional, defaults to 0.001) – The router auxiliary loss coefficient.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing strategy.
bits (int, optional) – The number of bits to quantize the model to.
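The num_experts_per_tok and num_experts parameters describe top-k expert routing. The sketch below shows a common softmax-then-top-k router in JAX, with the selected weights renormalized per token. It illustrates the general technique under that assumption and is not Grok-1's or EasyDeL's exact routing code; `route_tokens` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def route_tokens(router_logits: jnp.ndarray, num_experts_per_tok: int):
    """Select the top-k experts per token from router logits.

    router_logits: [num_tokens, num_experts]
    Returns (weights, expert_ids), each of shape [num_tokens, num_experts_per_tok].
    """
    probs = jax.nn.softmax(router_logits, axis=-1)
    # Keep the k highest-probability experts for each token.
    weights, expert_ids = jax.lax.top_k(probs, num_experts_per_tok)
    # Renormalize so each token's selected expert weights sum to 1.
    weights = weights / jnp.sum(weights, axis=-1, keepdims=True)
    return weights, expert_ids


logits = jax.random.normal(jax.random.PRNGKey(0), (4, 8))  # 4 tokens, 8 experts
weights, expert_ids = route_tokens(logits, num_experts_per_tok=2)
```

Each token's output is then the weighted sum of its selected experts' outputs.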
- attach_custom_arguments(tie_word_embeddings: bool = False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#
Attaches custom arguments to the configuration object.
This method allows adding or overriding configuration attributes dynamically. It primarily sets attributes related to word embeddings, gradient checkpointing, and quantization bits.
- Parameters
tie_word_embeddings (bool, optional) – Whether to tie input/output embeddings. Defaults to False.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing strategy. Defaults to EasyDeLGradientCheckPointers.NONE.
bits (tp.Optional[int], optional) – Quantization bits. Defaults to None.
**kwargs – Additional keyword arguments (ignored).
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model. This method defines how the model’s parameters are partitioned across devices for distributed training and inference.
- Parameters
*args – Additional positional arguments (unused).
**kwargs – Additional keyword arguments (unused).
- Returns
A tuple of partition rules, where each rule is a tuple containing a regex pattern for parameter names and the corresponding PartitionSpec.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
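Partition rules are matched against parameter paths by regular expression. The sketch below shows one plausible way to apply such a tuple by taking the first matching rule. The rule patterns, the mesh axis names ("fsdp", "tp"), and the `match_partition_spec` helper are all hypothetical illustrations, not Grok-1's actual rules.

```python
import re

from jax.sharding import PartitionSpec as PS

# Hypothetical rules: (regex over the parameter path, PartitionSpec).
PARTITION_RULES = (
    (r"embed_tokens/embedding", PS("tp", "fsdp")),
    (r"self_attn/(q_proj|k_proj|v_proj)/kernel", PS("fsdp", "tp")),
    (r"self_attn/o_proj/kernel", PS("tp", "fsdp")),
    (r".*", PS(None)),  # catch-all: replicate everything else
)


def match_partition_spec(param_path: str, rules=PARTITION_RULES) -> PS:
    """Return the PartitionSpec of the first rule whose regex matches param_path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"No partition rule matches {param_path!r}")


spec = match_partition_spec("model/layers/0/self_attn/q_proj/kernel")
```

Because the first match wins, rule order matters: specific patterns go first and a catch-all goes last.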
- static get_weight_decay_exclusions()[source]#
Returns a tuple of parameter names for which weight decay should be excluded.
- Returns
An empty tuple, indicating no specific weight decay exclusions for this model.
- Return type
tuple
- property granted_freq_max_position_embedding: int#
Returns the maximum position embedding size specifically for frequency-based position embeddings.
If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.
- Returns
The granted maximum position embedding size for frequency encoding.
- Return type
int
- property granted_mask_max_position_embedding: int#
Returns the maximum position embedding size specifically for mask-based position embeddings.
If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.
- Returns
The granted maximum position embedding size for mask encoding.
- Return type
int
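The fallback described by the two granted_* properties can be sketched as follows. `PositionConfigSketch` is a hypothetical stand-in that holds only the fields the properties read. It treats an override as "set" when it is not None, which is an assumption about the exact check used.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class PositionConfigSketch:
    """Hypothetical stand-in for the config fields the granted_* properties read."""

    max_position_embeddings: int = 4096
    freq_max_position_embeddings: Optional[int] = None
    mask_max_position_embeddings: Optional[int] = None

    @property
    def granted_freq_max_position_embedding(self) -> int:
        # Prefer the frequency-specific override; otherwise fall back.
        if self.freq_max_position_embeddings is not None:
            return self.freq_max_position_embeddings
        return self.max_position_embeddings

    @property
    def granted_mask_max_position_embedding(self) -> int:
        # Prefer the mask-specific override; otherwise fall back.
        if self.mask_max_position_embeddings is not None:
            return self.mask_max_position_embeddings
        return self.max_position_embeddings


cfg = PositionConfigSketch(freq_max_position_embeddings=8192)
```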
- model_type: str = 'grok-1'#
- class easydel.__init__.Grok1ForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Grok-1 model with a language modeling head.
This model extends the base Grok1Model by adding a linear layer on top to predict the next token in a sequence, making it suitable for causal language modeling tasks. It also includes handling for the Mixture of Experts auxiliary loss.
- config#
Configuration object for the model.
- Type
Grok1Config
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
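The Mixture of Experts auxiliary loss mentioned above is commonly a load-balancing term. The sketch below uses the Switch Transformer formulation (also used for Mixtral-style models): num_experts × Σᵢ fᵢ·Pᵢ, scaled by router_aux_loss_coef. Here fᵢ is the fraction of routing assignments sent to expert i and Pᵢ is its mean router probability. EasyDeL's Grok-1 implementation may compute the term differently, and `load_balancing_loss` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def load_balancing_loss(router_logits: jnp.ndarray, num_experts_per_tok: int, coef: float) -> jnp.ndarray:
    """Switch-Transformer-style auxiliary loss that rewards balanced expert usage.

    router_logits: [num_tokens, num_experts]
    """
    num_tokens, num_experts = router_logits.shape
    probs = jax.nn.softmax(router_logits, axis=-1)
    _, expert_ids = jax.lax.top_k(probs, num_experts_per_tok)
    # f_i: fraction of all (token, slot) assignments that went to expert i.
    one_hot = jax.nn.one_hot(expert_ids, num_experts)  # [tokens, k, experts]
    tokens_per_expert = one_hot.sum(axis=(0, 1)) / (num_tokens * num_experts_per_tok)
    # P_i: mean router probability assigned to expert i.
    prob_per_expert = probs.mean(axis=0)
    return coef * num_experts * jnp.sum(tokens_per_expert * prob_per_expert)


balanced = load_balancing_loss(jnp.zeros((16, 8)), num_experts_per_tok=2, coef=0.001)
```

With uniform router probabilities the term equals `coef`. It grows toward `coef * num_experts` as routing collapses onto a single expert.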
- class easydel.__init__.Grok1Model(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Grok-1 model implementation.
This class implements the main Grok-1 transformer model architecture, consisting of an embedding layer, multiple Grok1DecoderLayer layers (with sparse MoE), and a final RMS normalization layer.
- config#
Configuration object for the model.
- Type
Grok1Config
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- property frequencies#
Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.
Uses self.config.get_basic_frequencies() and caches the result.
- Returns
The frequency components, potentially cached.
- Return type
jnp.ndarray
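For reference, the standard rotary position embedding (RoPE) table built from a base theta (the rope_theta config value) can be computed as below, together with the "rotate-half" application to a query or key vector. This is the textbook formulation, given as an assumption for illustration. It does not reproduce the internals of get_basic_frequencies() or any rope_scaling variant.

```python
import jax
import jax.numpy as jnp


def rope_frequencies(head_dim: int, max_positions: int, theta: float = 10000.0):
    """Return (cos, sin) tables of shape [max_positions, head_dim // 2]."""
    # Inverse frequencies theta ** (-2i / head_dim) for i = 0 .. head_dim/2 - 1.
    inv_freq = 1.0 / (theta ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim))
    positions = jnp.arange(max_positions, dtype=jnp.float32)
    angles = jnp.outer(positions, inv_freq)
    return jnp.cos(angles), jnp.sin(angles)


def apply_rope(x, cos, sin):
    """Rotate x of shape [seq, head_dim] using the "rotate-half" pairing convention."""
    seq_len = x.shape[0]
    x1, x2 = jnp.split(x, 2, axis=-1)
    c, s = cos[:seq_len], sin[:seq_len]
    return jnp.concatenate([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1)


cos, sin = rope_frequencies(head_dim=8, max_positions=16)
x = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
rotated = apply_rope(x, cos, sin)
```

Caching the tables, as the property does, avoids recomputing them on every forward pass.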
- class easydel.__init__.InternLM2Config(vocab_size=103168, hidden_size=4096, intermediate_size=11008, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=None, hidden_act='silu', max_position_embeddings=2048, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, pad_token_id=0, bos_token_id=1, eos_token_id=2, pretraining_tp=1, tie_word_embeddings=False, bias=True, rope_theta=10000, rope_scaling=None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, fcm_min_ratio: float = -1, fcm_max_ratio: float = -1, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, scan_layers: bool = False, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 103168) – Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.
hidden_size (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.
intermediate_size (int, optional, defaults to 11008) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer encoder. Will default to num_attention_heads if not set.
max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. This is typically set to a generous value (e.g., 2048 or 4096).
rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the RMS normalization layers.
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
use_cache (bool, optional, defaults to True) – Whether the model should return the last key/value attention states (not used by all models). Only relevant if config.is_decoder=True.
pad_token_id (int, optional, defaults to 0) – The id of the pad token.
bos_token_id (int, optional, defaults to 1) – The id of the beginning-of-sequence token.
eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.
bias (bool, optional, defaults to True) – Whether to use attention bias.
tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing strategy.
fcm_min_ratio (float, optional, defaults to -1) – The minimum ratio for forgetful causal masking (FCM).
fcm_max_ratio (float, optional, defaults to -1) – The maximum ratio for forgetful causal masking (FCM).
rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.
scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size to use when scanning the MLP.
bits (int, optional) – The number of bits to quantize the model to.
hidden_act (str, optional, defaults to “silu”) – The hidden activation function to use.
pretraining_tp (int, optional, defaults to 1) – The tensor parallelism degree used during pretraining.
mlp_bias (bool, optional, defaults to False) – Whether to use bias in the MLP.
scan_layers (bool, optional, defaults to False) – Whether to use the scan implementation for the layers.
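The sketch below assumes fcm_min_ratio and fcm_max_ratio follow the EasyLM formulation this code base draws on. Under that formulation, a masking ratio is sampled per example between the two bounds, that fraction of key positions is randomly masked on top of the causal mask, and a non-positive fcm_max_ratio disables the feature. `forgetful_causal_mask` is a hypothetical helper and not EasyDeL's implementation.

```python
import jax
import jax.numpy as jnp


def forgetful_causal_mask(rng, batch_size: int, seq_length: int, fcm_min_ratio: float, fcm_max_ratio: float):
    """Boolean attention mask [batch, 1, q_len, kv_len]: causal AND randomly kept keys."""
    causal = jnp.tril(jnp.ones((seq_length, seq_length), dtype=bool))[None, None]
    if fcm_max_ratio <= 0:
        # FCM disabled: plain causal mask.
        return jnp.broadcast_to(causal, (batch_size, 1, seq_length, seq_length))
    ratio_rng, mask_rng = jax.random.split(rng)
    # One masking ratio per example, drawn from [fcm_min_ratio, fcm_max_ratio).
    ratio = jax.random.uniform(ratio_rng, (batch_size, 1, 1, 1), minval=fcm_min_ratio, maxval=fcm_max_ratio)
    # Keep each key position with probability (1 - ratio); always keep position 0.
    keep = jax.random.uniform(mask_rng, (batch_size, 1, 1, seq_length)) > ratio
    keep = keep.at[:, :, :, 0].set(True)
    return causal & keep


mask = forgetful_causal_mask(jax.random.PRNGKey(0), 2, 6, fcm_min_ratio=0.2, fcm_max_ratio=0.5)
```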
- attach_custom_arguments(tie_word_embeddings: bool = False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, fcm_min_ratio: float = 0.0, fcm_max_ratio: float = 0.0, bits: Optional[int] = None, rope_theta: float = 10000.0, hidden_act: str = 'silu', scan_layers: bool = True, **kwargs)[source]#
Attaches custom arguments to the configuration object.
This method allows adding or overriding configuration attributes dynamically. It iterates through the provided arguments and sets them as attributes of the configuration object.
- Parameters
tie_word_embeddings (bool, optional) – Whether to tie input/output embeddings. Defaults to False.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing strategy. Defaults to EasyDeLGradientCheckPointers.NONE.
fcm_min_ratio (float, optional) – Minimum ratio for forgetful causal masking (FCM). Defaults to 0.0.
fcm_max_ratio (float, optional) – Maximum ratio for forgetful causal masking (FCM). Defaults to 0.0.
bits (tp.Optional[int], optional) – Quantization bits. Defaults to None.
rope_theta (float, optional) – Base value for RoPE. Defaults to 10000.0.
hidden_act (str, optional) – Activation function. Defaults to “silu”.
scan_layers (bool, optional) – Whether to use scan layers. Defaults to True.
**kwargs – Additional keyword arguments (ignored in this implementation).
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model. This method defines how the model’s parameters are partitioned across devices for distributed training and inference.
- Parameters
*args – Additional positional arguments (unused).
**kwargs – Additional keyword arguments (unused).
- Returns
A tuple of partition rules, where each rule is a tuple containing a regex pattern for parameter names and the corresponding PartitionSpec.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- static get_weight_decay_exclusions()[source]#
Returns a tuple of parameter names for which weight decay should be excluded.
- Returns
An empty tuple, indicating no specific weight decay exclusions for this model.
- Return type
tuple
- property granted_freq_max_position_embedding: int#
Returns the maximum position embedding size specifically for frequency-based position embeddings.
If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.
- Returns
The granted maximum position embedding size for frequency encoding.
- Return type
int
- property granted_mask_max_position_embedding: int#
Returns the maximum position embedding size specifically for mask-based position embeddings.
If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.
- Returns
The granted maximum position embedding size for mask encoding.
- Return type
int
- model_type: str = 'internlm2'#
- class easydel.__init__.InternLM2ForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
InternLM2 model with a Causal Language Modeling head.
This model consists of the base InternLM2 transformer (InternLM2Model) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction.
- config#
Configuration object for the model.
- Type
InternLM2Config
- dtype#
Data type for computation. Default is jnp.float32.
- Type
jnp.dtype
- param_dtype#
Data type for parameters. Default is jnp.float32.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations. Default is None.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- module#
The core InternLM2 transformer model.
- Type
InternLM2Model
- lm_head#
The linear layer for projecting hidden states to vocabulary logits.
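The lm_head projection and next-token selection can be sketched in plain JAX. The random `lm_head_kernel` is a hypothetical stand-in for real weights, and greedy argmax decoding is just one simple choice of sampling strategy.

```python
import jax
import jax.numpy as jnp

batch, seq_len, hidden_size, vocab_size = 2, 5, 16, 100
key_h, key_w = jax.random.split(jax.random.PRNGKey(0))
hidden_states = jax.random.normal(key_h, (batch, seq_len, hidden_size))
# Hypothetical lm_head weights; a real model would load these from a checkpoint.
lm_head_kernel = 0.02 * jax.random.normal(key_w, (hidden_size, vocab_size))

# Project every position's hidden state to vocabulary logits.
logits = jnp.einsum("bsh,hv->bsv", hidden_states, lm_head_kernel)
# During generation only the last position's logits are needed for the next token.
next_token = jnp.argmax(logits[:, -1, :], axis=-1)
```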
- class easydel.__init__.InternLM2ForSequenceClassification(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
InternLM2 model with a Sequence Classification head.
This model consists of the base InternLM2 transformer (InternLM2Model) followed by a linear layer (score) that projects the transformer’s output hidden states (typically the hidden state of the last non-padding token) to the number of classes for classification.
- config#
Configuration object for the model.
- Type
InternLM2Config
- dtype#
Data type for computation. Default is jnp.float32.
- Type
jnp.dtype
- param_dtype#
Data type for parameters. Default is jnp.float32.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations. Default is None.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- module#
The core InternLM2 transformer model.
- Type
InternLM2Model
- score#
The linear layer for classification.
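Decoder-only classifiers commonly pool the hidden state at the last non-padding token; Hugging Face's LlamaForSequenceClassification is one example of this convention. The sketch below shows that pooling, followed by a score projection. It assumes right padding, and `last_token_pool` and `score_kernel` are hypothetical names.

```python
import jax.numpy as jnp


def last_token_pool(hidden_states, input_ids, pad_token_id: int):
    """Return [batch, hidden] states taken at each row's last non-padding token.

    Assumes right padding and that pad_token_id never appears inside the real text.
    """
    lengths = jnp.sum(input_ids != pad_token_id, axis=-1)
    last_index = lengths - 1
    batch_index = jnp.arange(hidden_states.shape[0])
    return hidden_states[batch_index, last_index]


hidden_states = jnp.arange(2 * 4 * 3, dtype=jnp.float32).reshape(2, 4, 3)
input_ids = jnp.array([[5, 6, 7, 0], [8, 9, 0, 0]])  # 0 is the pad id here
pooled = last_token_pool(hidden_states, input_ids, pad_token_id=0)
score_kernel = jnp.ones((3, 2))  # hypothetical classification head: hidden=3 -> 2 classes
class_logits = pooled @ score_kernel
```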
- class easydel.__init__.InternLM2Model(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
The base InternLM2 model transformer.
This class represents the core transformer architecture of the InternLM2 model, consisting of embedding layers, multiple transformer blocks, and a final layer normalization.
- config#
Configuration object for the model.
- Type
InternLM2Config
- dtype#
Data type for computation. Default is jnp.float32.
- Type
jnp.dtype
- param_dtype#
Data type for parameters. Default is jnp.float32.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations. Default is None.
- Type
jax.lax.PrecisionLike
- embed_tokens#
Embedding layer for input tokens.
- Type
nn.Embed
- layers#
Sequence of transformer blocks.
- Type
tp.Sequence[InternLM2Block]
- gradient_checkpointing#
Gradient checkpointing configuration.
- scan_layers#
Whether to use JAX scan for layer processing.
- Type
bool
- blocks_class#
The class used for the transformer blocks.
- Type
- class easydel.__init__.JaxDistributedConfig[source]#
Bases:
object
Utility class for initializing JAX distributed, adapted from EasyLM.
- class easydel.__init__.Llama4Config(vision_config=None, text_config=None, boi_token_index=200080, eoi_token_index=200081, image_token_index=200092, tie_word_embeddings=False, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the Llama4 model.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- model_type: str = 'llama4'#
- sub_configs: dict[str, 'PretrainedConfig'] = {'text_config': <class 'easydel.__init__.modules.llama4.llama4_configuration.Llama4TextConfig'>, 'vision_config': <class 'easydel.__init__.modules.llama4.llama4_configuration.Llama4VisionConfig'>}#
- class easydel.__init__.Llama4ForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.Llama4ForConditionalGeneration(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Llama4Vision model for conditional text generation based on image inputs. Combines a vision tower and a language model with a multi-modal projector.
- config#
Configuration object.
- Type
Llama4Config
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
JAX precision level.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- get_image_features(pixel_values: Union[Array, ndarray, bool, number], **kwargs) Union[Array, ndarray, bool, number][source]#
Extracts and projects image features from the vision tower.
- Parameters
pixel_values (chex.Array) – Input pixel values for the images.
- Returns
Processed image features ready for the language model.
- Return type
chex.Array
- loss_type = 'ForCausalLM'#
- prepare_inputs_for_generation(input_ids: Union[Array, ndarray, bool, number], max_length: int, pixel_values: Optional[Union[Array, ndarray, bool, number]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None)[source]#
Prepares inputs for text generation, including pixel values if provided.
- Parameters
input_ids (chex.Array) – Initial input token IDs.
max_length (int) – Maximum generation length.
pixel_values (Optional[chex.Array]) – Pixel values for image input.
attention_mask (Optional[chex.Array]) – Attention mask.
- Returns
Model inputs ready for generation.
- Return type
dict
- update_inputs_for_generation(model_outputs, model_kwargs)[source]#
Updates model inputs for the next step of generation, removing pixel values after the first step.
- Parameters
model_outputs – Outputs from the previous generation step.
model_kwargs – Current keyword arguments for the model.
- Returns
Updated model keyword arguments.
- Return type
dict
- class easydel.__init__.Llama4ForSequenceClassification(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Llama4 model for sequence classification tasks.
This class extends the base Llama4 model by adding a linear classification head to perform sequence classification tasks such as sentiment analysis or text classification.
- config#
Configuration for the model.
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- class easydel.__init__.Llama4TextConfig(vocab_size=202048, hidden_size=5120, intermediate_size=8192, intermediate_size_mlp=16384, num_hidden_layers=48, num_attention_heads=40, num_key_value_heads=8, head_dim=128, hidden_act='silu', max_position_embeddings=131072, initializer_range=0.02, rms_norm_eps=1e-05, use_cache=True, pad_token_id=None, bos_token_id=1, eos_token_id=2, tie_word_embeddings=False, rope_theta=500000, attention_dropout=0.0, num_experts_per_tok=1, num_local_experts=16, moe_layers=None, interleave_moe_layer_step=1, use_qk_norm=True, output_router_logits=False, router_aux_loss_coef=0.001, router_jitter_noise=0.0, rope_scaling=None, no_rope_layers=None, no_rope_layer_interval=4, attention_chunk_size=8192, attn_temperature_tuning=4, floor_scale=8192, attn_scale=0.1, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the Llama4Text model.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- model_type: str = 'llama4_text'#
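The sketch below assumes EasyDeL mirrors the Hugging Face transformers Llama4TextConfig defaults when no_rope_layers and moe_layers are left as None. Under that assumption, the per-layer schedules follow from no_rope_layer_interval and interleave_moe_layer_step. Both helper names are hypothetical.

```python
def default_no_rope_layers(num_hidden_layers: int, no_rope_layer_interval: int) -> list[int]:
    """1 = layer applies RoPE, 0 = layer skips RoPE (NoPE); every interval-th layer is NoPE."""
    return [int((i + 1) % no_rope_layer_interval != 0) for i in range(num_hidden_layers)]


def default_moe_layers(num_hidden_layers: int, interleave_moe_layer_step: int) -> list[int]:
    """Indices of layers whose feed-forward block is a mixture of experts."""
    return list(range(interleave_moe_layer_step - 1, num_hidden_layers, interleave_moe_layer_step))


print(default_no_rope_layers(8, 4))  # [1, 1, 1, 0, 1, 1, 1, 0]
print(default_moe_layers(8, 2))      # [1, 3, 5, 7]
```

With the listed defaults (interleave_moe_layer_step=1), every layer would be an MoE layer.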
- class easydel.__init__.Llama4TextModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.Llama4VisionConfig(hidden_size: int = 768, hidden_act: str = 'gelu', num_hidden_layers: int = 34, num_attention_heads: int = 16, num_channels: int = 3, intermediate_size: int = 5632, vision_output_dim: int = 7680, image_size: int = 448, patch_size: int = 14, norm_eps: float = 1e-05, vision_feature_layer=-1, vision_feature_select_strategy='default', initializer_range: float = 0.02, pixel_shuffle_ratio=0.5, projector_input_dim=4096, projector_output_dim=4096, multi_modal_projector_bias=False, projector_dropout=0.0, attention_dropout=0.0, rope_theta=10000, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
- base_config_key: str = 'vision_config'#
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the Llama4Vision model.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- model_type: str = 'llama4_vision_model'#
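With the defaults image_size=448 and patch_size=14, the encoder produces (448 / 14)² = 1024 patch embeddings. A pixel_shuffle_ratio of 0.5 trades spatial resolution for channels, which leaves 1024 × 0.5² = 256 tokens with 4× the channels. The JAX function below is a sketch of a pixel-shuffle port modeled on the Hugging Face transformers Llama4 vision code as we understand it. EasyDeL's exact permutation order may differ.

```python
import math

import jax.numpy as jnp


def pixel_shuffle(x: jnp.ndarray, shuffle_ratio: float) -> jnp.ndarray:
    """x: [batch, num_patches, channels] -> [batch, num_patches * r**2, channels / r**2]."""
    batch, num_patches, channels = x.shape
    side = int(math.sqrt(num_patches))
    x = x.reshape(batch, side, side, channels)
    # Fold part of the width axis into channels ...
    x = x.reshape(batch, side, int(side * shuffle_ratio), int(channels / shuffle_ratio))
    x = x.transpose(0, 2, 1, 3)
    # ... then fold part of the height axis into channels.
    x = x.reshape(batch, int(side * shuffle_ratio), int(side * shuffle_ratio), int(channels / shuffle_ratio**2))
    x = x.transpose(0, 2, 1, 3)
    return x.reshape(batch, -1, x.shape[-1])


tokens = pixel_shuffle(jnp.zeros((1, 1024, 8)), shuffle_ratio=0.5)  # -> shape (1, 256, 32)
```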
- class easydel.__init__.Llama4VisionModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- property frequencies#
Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.
Uses self.config.get_basic_frequencies() and caches the result.
- Returns
The frequency components, potentially cached.
- Return type
jnp.ndarray
- class easydel.__init__.LlamaConfig(vocab_size: int = 32000, hidden_size: int = 4096, intermediate_size: int = 11008, num_hidden_layers: int = 32, num_attention_heads: int = 32, number_rep_kv: int = 1, head_dim: Optional[int] = None, num_key_value_heads: Optional[int] = None, max_position_embeddings: int = 2048, rms_norm_eps: float = 1e-06, initializer_range: float = 0.02, use_cache: bool = True, bos_token_id: int = 0, eos_token_id: int = 1, resid_pdrop: float = 0.0, embd_pdrop: float = 0.0, attention_dropout: float = 0.0, rope_theta: float = 10000.0, attention_bias: bool = False, tie_word_embeddings: bool = False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, fcm_min_ratio: float = -1, fcm_max_ratio: float = -1, rope_scaling: Dict[str, Union[str, float]] = None, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, hidden_act: str = 'silu', pretraining_tp: int = 1, mlp_bias: bool = False, scan_layers: bool = False, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 32000) – Vocabulary size of the Llama model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.
hidden_size (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.
intermediate_size (int, optional, defaults to 11008) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.
number_rep_kv (int, optional, defaults to 1) – Number of repetitions for the key and value vectors.
num_key_value_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer encoder. Will default to number_rep_kv * num_attention_heads if not set.
max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. This is typically set to a generous value (e.g., 2048 or 4096).
head_dim (int, optional) – The dimension of each attention head used for the query, key, and value projections.
rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the RMS normalization layers.
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
use_cache (bool, optional, defaults to True) – Whether the model should return the last key/value attention states (not used by all models). Only relevant if config.is_decoder=True.
bos_token_id (int, optional, defaults to 0) – The id of the beginning-of-sequence token.
eos_token_id (int, optional, defaults to 1) – The id of the end-of-sequence token.
resid_pdrop (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
embd_pdrop (float, optional, defaults to 0.0) – The dropout ratio for the embeddings.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.
attention_bias (bool, optional, defaults to False) – Whether to use attention bias.
tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing strategy.
fcm_min_ratio (float, optional, defaults to -1) – The minimum ratio for forgetful causal masking (FCM).
fcm_max_ratio (float, optional, defaults to -1) – The maximum ratio for forgetful causal masking (FCM).
rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.
scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size to use when scanning the MLP.
bits (int, optional) – The number of bits to quantize the model to.
hidden_act (str, optional, defaults to “silu”) – The hidden activation function to use.
pretraining_tp (int, optional, defaults to 1) – The tensor parallelism degree used during pretraining.
mlp_bias (bool, optional, defaults to False) – Whether to use bias in the MLP.
scan_layers (bool, optional, defaults to False) – Whether to use the scan implementation for the layers.
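When num_key_value_heads is smaller than num_attention_heads (grouped-query attention), each key/value head is shared by a group of query heads. The sketch below shows that repetition step in JAX, assuming a [batch, seq, heads, head_dim] layout. It does not show how EasyDeL's attention modules or number_rep_kv implement it, and `repeat_kv` is a hypothetical helper.

```python
import jax.numpy as jnp


def repeat_kv(kv: jnp.ndarray, num_attention_heads: int) -> jnp.ndarray:
    """Expand [batch, seq, num_kv_heads, head_dim] to [batch, seq, num_attention_heads, head_dim]."""
    num_kv_heads = kv.shape[2]
    if num_attention_heads % num_kv_heads != 0:
        raise ValueError("num_attention_heads must be a multiple of num_key_value_heads")
    # Each key/value head is shared by a contiguous group of query heads.
    return jnp.repeat(kv, num_attention_heads // num_kv_heads, axis=2)


kv = jnp.arange(2, dtype=jnp.float32).reshape(1, 1, 2, 1)  # 2 KV heads
expanded = repeat_kv(kv, num_attention_heads=4)             # 4 query heads
```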
- attach_custom_arguments(resid_pdrop: float = 0.0, embd_pdrop: float = 0.0, attention_dropout: float = 0.0, tie_word_embeddings: bool = False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, fcm_min_ratio: float = 0.0, fcm_max_ratio: float = 0.0, number_rep_kv: int = 1, bits: Optional[int] = None, rope_theta: float = 10000.0, attention_bias: bool = False, hidden_act: str = 'silu', scan_layers: bool = True, **kwargs)[source]#
Attaches custom arguments to the configuration object, setting attributes related to dropout, embedding tying, gradient checkpointing, forgetful causal masking, quantization, RoPE, attention bias, the MLP activation, and layer scanning.
- Parameters
resid_pdrop (float, optional) – Dropout rate for residual connections. Defaults to 0.0.
embd_pdrop (float, optional) – Dropout rate for the embeddings. Defaults to 0.0.
attention_dropout (float, optional) – Dropout rate for the attention probabilities. Defaults to 0.0.
tie_word_embeddings (bool, optional) – Whether to tie input/output embeddings. Defaults to False.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing strategy, which trades recomputation for lower memory use. Defaults to EasyDeLGradientCheckPointers.NONE.
fcm_min_ratio (float, optional) – Minimum ratio for forgetful causal masking (FCM). Defaults to 0.0.
fcm_max_ratio (float, optional) – Maximum ratio for forgetful causal masking (FCM). Defaults to 0.0.
number_rep_kv (int, optional) – Number of times the key and value vectors are repeated. Defaults to 1.
bits (tp.Optional[int], optional) – Quantization bits. Defaults to None.
rope_theta (float, optional) – Base value for RoPE. Defaults to 10000.0.
attention_bias (bool, optional) – Whether to use bias in the query, key, value, and output projections of self-attention. Defaults to False.
hidden_act (str, optional) – Activation function for the MLP. Defaults to “silu”.
scan_layers (bool, optional) – Whether to use scan layers. Defaults to True.
**kwargs – Additional keyword arguments.
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- property granted_freq_max_position_embedding: int#
- property granted_mask_max_position_embedding: int#
- model_type: str = 'llama'#
- class easydel.__init__.LlamaForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Llama model with a language modeling head for causal language modeling tasks.
This model is a transformer-based language model with causal attention masks applied to perform autoregressive language generation.
- config#
Configuration for the model.
- Type
LlamaConfig
- dtype#
Data type for computations (default is jnp.float32).
- Type
jnp.dtype
- param_dtype#
Data type for parameters (default is jnp.float32).
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
tp.Optional[tp.Union[str, jax.lax.Precision]]
- class easydel.__init__.LlamaForSequenceClassification(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Llama model for sequence classification tasks.
This class extends the base Llama model by adding a linear classification head to perform sequence classification tasks such as sentiment analysis or text classification.
- config#
Configuration for the model.
- Type
LlamaConfig
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- class easydel.__init__.LlamaModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Llama model implementation.
This implements the Llama language model architecture: a stack of transformer blocks using RMSNorm, rotary position embeddings, and self-attention with optional grouped-query key/value heads (see num_key_value_heads).
- config#
Configuration for the model.
- Type
LlamaConfig
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- class easydel.__init__.LlavaConfig(vision_config=None, text_config=None, image_token_index=32000, projector_hidden_act='gelu', vision_feature_select_strategy='default', vision_feature_layer=-2, image_seq_length=576, multimodal_projector_bias=True, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
This is the configuration class to store the configuration of a [LlavaForConditionalGeneration]. It is used to instantiate a Llava model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration similar to that of Llava-9B.
e.g. [llava-hf/llava-9b](https://huggingface.co/llava-hf/llava-9b)
Configuration objects inherit from [PretrainedConfig] and can be used to control the model outputs. Read the documentation from [PretrainedConfig] for more information.
- Parameters
vision_config (Union[AutoConfig, dict], optional, defaults to CLIPVisionConfig) – The config object or dictionary of the vision backbone.
text_config (Union[AutoConfig, dict], optional, defaults to LlamaConfig) – The config object or dictionary of the text backbone.
image_token_index (int, optional, defaults to 32000) – The image token index to encode the image prompt.
projector_hidden_act (str, optional, defaults to “gelu”) – The activation function used by the multimodal projector.
vision_feature_select_strategy (str, optional, defaults to “default”) – The feature selection strategy used to select the vision feature from the vision backbone. Can be one of “default” or “full”.
vision_feature_layer (Union[int, List[int]], optional, defaults to -2) – The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features.
image_seq_length (int, optional, defaults to 576) – Sequence length of one image embedding.
multimodal_projector_bias (bool, optional, defaults to True) – Whether to use bias in the multimodal projector.
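The vision_feature_layer and vision_feature_select_strategy options can be illustrated with a small JAX sketch. It assumes the vision backbone returns one hidden-state tensor per layer, with a CLS token in position 0. Under that assumption, "default" drops the CLS token, "full" keeps it, and a list of layer indices concatenates features along the hidden dimension. `select_vision_features` is a hypothetical helper. For example, a CLIP ViT-L/14 at 336 px yields 24 × 24 = 576 patches plus CLS, which matches image_seq_length=576 after the "default" strategy.

```python
import jax.numpy as jnp


def select_vision_features(hidden_states, vision_feature_layer, strategy: str = "default"):
    """hidden_states: list of per-layer outputs, each [batch, 1 + num_patches, dim] (CLS first)."""
    if isinstance(vision_feature_layer, int):
        features = hidden_states[vision_feature_layer]
    else:
        # Several layers: concatenate their features along the hidden dimension.
        features = jnp.concatenate([hidden_states[i] for i in vision_feature_layer], axis=-1)
    if strategy == "default":
        features = features[:, 1:]  # drop the CLS token
    elif strategy != "full":
        raise ValueError(f"Unknown vision_feature_select_strategy: {strategy!r}")
    return features


layers = [jnp.zeros((2, 577, 8)) for _ in range(4)]  # 4 layers, 576 patches + CLS
feats = select_vision_features(layers, vision_feature_layer=-2)  # -> shape (2, 576, 8)
```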
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for distributed training by combining the partition rules from both the text and vision configurations.
This method retrieves the partition rules from the text_config and vision_config components and combines them to create a comprehensive set of rules for the entire multimodal model.
- Parameters
*args – Variable length argument list to be passed to the text and vision configs.
**kwargs – Arbitrary keyword arguments to be passed to the text and vision configs.
- Returns
A combined tuple of partition rules from both text and vision configurations.
- Return type
tuple
- model_type: str = 'llava'#
- sub_configs: dict[str, 'PretrainedConfig'] = {'text_config': <class 'easydel.__init__.modules.auto.auto_configuration.AutoEasyDeLConfig'>, 'vision_config': <class 'easydel.__init__.modules.auto.auto_configuration.AutoEasyDeLConfig'>}#
- class easydel.__init__.LlavaForConditionalGeneration(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- get_image_features(pixel_values: Union[Array, ndarray, bool, number]) Union[Array, ndarray, bool, number][source]#
- loss_type = 'ForCausalLM'#
- prepare_inputs_for_generation(input_ids: Union[Array, ndarray, bool, number], max_length: int, pixel_values: Optional[Union[Array, ndarray, bool, number]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None)[source]#
Prepares inputs for text generation, including pixel values if provided.
- Parameters
input_ids (chex.Array) – Initial input token IDs.
max_length (int) – Maximum length of the sequence to be generated.
pixel_values (Optional[chex.Array]) – Pixel values for image input.
attention_mask (Optional[chex.Array]) – Attention mask.
- Returns
A dictionary of the past_key_values, attention_mask, and position ids.
- Return type
dict
- class easydel.__init__.LossConfig(ignore_index: int = -100, label_smoothing: float = 0.0, z_loss: float = 0.0, loss_normalizing_factor: Optional[Union[float, int, str, SpecialLossNormalizingFactor]] = 'NUM_REAL_TARGET_TOKENS', num_labels: Optional[str] = None, problem_type: Optional[str] = None, divide_weight_sum: bool = False, shift_tokens: bool = True, break_on_nan: bool = True, reduction: Optional[Literal['none', 'mean', 'sum']] = None, num_classification_labels: Optional[int] = None, classification_problem_type: Optional[Literal['regression', 'single_label_classification', 'multi_label_classification']] = None)[source]#
Bases:
object
Configuration class for customizing loss computation behavior.
- ignore_index#
Specifies a target value that is ignored and does not contribute to the loss. Defaults to -100.
- Type
int
- label_smoothing#
Amount of label smoothing to apply. 0.0 means no smoothing. Defaults to 0.0.
- Type
float
- z_loss#
Coefficient for the z-loss regularization term, which penalizes the squared log of the softmax normalizer (log Z). This keeps the logits from drifting to large magnitudes. Defaults to 0.0.
- Type
float
- loss_normalizing_factor#
How to normalize the loss. Can be a constant float/int, a string representation of a SpecialLossNormalizingFactor enum, or the enum itself. Defaults to “NUM_REAL_TARGET_TOKENS”.
- Type
FACTOR_TYPE
- num_labels#
The number of labels for classification tasks. Used in ForSequenceClassificationLoss. Defaults to None.
- Type
tp.Optional[int]
- problem_type#
Specifies the problem type for sequence classification (e.g., “single_label_classification”, “multi_label_classification”). Defaults to None.
- Type
tp.Optional[str]
- divide_weight_sum#
If True, divides the loss by the sum of weights, in addition to the loss_normalizing_factor. Defaults to False.
- Type
bool
- shift_tokens#
If True (typically for Causal LM), shifts the logits and labels so that the model predicts the next token. Defaults to True.
- Type
bool
- break_on_nan#
If True, raises an EasyDeLBreakRequest if a NaN is encountered during loss computation. Defaults to True.
- Type
bool
- reduction#
Specifies the reduction to apply to the loss. If None, the default reduction of the specific loss function is used. Defaults to None.
- Type
tp.Optional[tp.Literal[“none”, “mean”, “sum”]]
- num_classification_labels#
Number of labels specifically for sequence classification. Alias for num_labels. Defaults to None.
- Type
tp.Optional[int]
- classification_problem_type#
Problem type specifically for sequence classification. Alias for problem_type. Defaults to None.
- Type
tp.Optional[tp.Literal[“regression”, “single_label_classification”, “multi_label_classification”]]
- break_on_nan: bool = True#
- classification_problem_type: Optional[Literal['regression', 'single_label_classification', 'multi_label_classification']] = None#
- divide_weight_sum: bool = False#
- classmethod from_dict(data)#
Create an instance from a dictionary (deserialization).
- classmethod from_json(json_str)#
Create an instance from a JSON string.
- ignore_index: int = -100#
- label_smoothing: float = 0.0#
- loss_normalizing_factor: Optional[Union[float, int, str, SpecialLossNormalizingFactor]] = 'NUM_REAL_TARGET_TOKENS'#
- num_classification_labels: Optional[int] = None#
- num_labels: Optional[str] = None#
- problem_type: Optional[str] = None#
- reduction: Optional[Literal['none', 'mean', 'sum']] = None#
- replace(**kwargs)#
- shift_tokens: bool = True#
- to_dict()#
Convert the instance to a dictionary for JSON serialization.
- to_json(**kwargs)#
Convert the instance to a JSON string.
- z_loss: float = 0.0#
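The minimal NumPy sketch below shows how ignore_index, shift_tokens, label_smoothing and z_loss interact in a token-level cross-entropy. It makes two assumptions: one common label-smoothing formulation (eps / vocab mass on every class), and normalization by the number of non-ignored targets, in the spirit of "NUM_REAL_TARGET_TOKENS". EasyDeL's actual loss functions may differ in both details. sketch_token_loss is a hypothetical name.

```python
import numpy as np


def sketch_token_loss(logits, labels, ignore_index=-100, label_smoothing=0.0,
                      z_loss=0.0, shift_tokens=True):
    """Mean per-token loss over non-ignored targets (illustrative only).

    logits: (batch, seq, vocab) floats; labels: (batch, seq) integer ids.
    """
    if shift_tokens:
        # Position t is scored against the token at position t + 1.
        logits, labels = logits[:, :-1, :], labels[:, 1:]
    vocab = logits.shape[-1]
    valid = labels != ignore_index
    safe_labels = np.where(valid, labels, 0)  # any in-range id; masked out below
    # Numerically stable log-softmax; log_z is the log of the softmax normalizer.
    max_logit = logits.max(axis=-1, keepdims=True)
    log_z = max_logit + np.log(np.exp(logits - max_logit).sum(axis=-1, keepdims=True))
    log_probs = logits - log_z
    # Smoothed targets: eps / vocab everywhere, plus (1 - eps) on the true label.
    targets = np.full(logits.shape, label_smoothing / vocab)
    np.put_along_axis(targets, safe_labels[..., None],
                      1.0 - label_smoothing + label_smoothing / vocab, axis=-1)
    token_loss = -(targets * log_probs).sum(axis=-1)
    # z-loss term: penalize large log-partition values.
    token_loss = token_loss + z_loss * log_z[..., 0] ** 2
    token_loss = np.where(valid, token_loss, 0.0)
    # Divide by the count of real (non-ignored) target tokens.
    return token_loss.sum() / max(int(valid.sum()), 1)
```

With uniform logits over a vocabulary of 4, every valid target costs log(4), whether or not smoothing is on. Positions labeled -100 simply drop out of both the sum and the count.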
- class easydel.__init__.Mamba2Config(num_heads=128, head_dim=64, vocab_size=32768, hidden_size=4096, state_size=128, num_hidden_layers=64, layer_norm_epsilon=1e-05, pad_token_id=1, bos_token_id=0, eos_token_id=2, expand=2, conv_kernel=4, n_groups=8, use_bias=False, use_conv_bias=True, hidden_act='silu', initializer_range=0.1, residual_in_fp32=True, time_step_rank='auto', time_step_min=0.001, time_step_max=0.1, time_step_floor=0.0001, time_step_limit=(0.0, inf), rescale_prenorm_residual=False, use_cache=True, norm_before_gate=True, rms_norm=True, chunk_size=256, tie_word_embeddings=False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 32768) – Vocabulary size of the Mamba2 model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.
hidden_size (int, optional, defaults to 4096) – Dimensionality of the embeddings and hidden states.
state_size (int, optional, defaults to 128) – State size of the Mamba2 model.
num_hidden_layers (int, optional, defaults to 64) – Number of hidden layers (Mamba2 blocks) in the model.
num_heads (int, optional, defaults to 128) – Number of heads for the grouped selective scan.
head_dim (int, optional, defaults to 64) – Dimension of each head.
n_groups (int, optional, defaults to 8) – Number of groups for the grouped selective scan.
layer_norm_epsilon (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.
pad_token_id (int, optional, defaults to 1) – The index of the padding token in the vocabulary.
bos_token_id (int, optional, defaults to 0) – The id of the beginning-of-sequence token.
eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.
expand (int, optional, defaults to 2) – Expansion factor for the intermediate size.
conv_kernel (int, optional, defaults to 4) – Kernel size of the convolution layer.
use_bias (bool, optional, defaults to False) – Whether to use bias in the linear layers.
use_conv_bias (bool, optional, defaults to True) – Whether to use bias in the convolution layer.
hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string). If string, “silu”, “gelu”, “relu”, “swish” and “gelu_new” are supported.
initializer_range (float, optional, defaults to 0.1) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
residual_in_fp32 (bool, optional, defaults to True) – Whether to compute the residual connection in float32.
time_step_rank (str or int, optional, defaults to “auto”) – The rank of the time step embedding. If set to “auto”, the rank is calculated as math.ceil(self.hidden_size / 16).
time_step_min (float, optional, defaults to 0.001) – The minimum value for the time step embedding.
time_step_max (float, optional, defaults to 0.1) – The maximum value for the time step embedding.
time_step_floor (float, optional, defaults to 1e-4) – The floor value for the time step embedding.
time_step_limit (tuple, optional, defaults to (0.0, float(“inf”))) – The minimum and maximum limits for the time step.
rescale_prenorm_residual (bool, optional, defaults to False) – Whether to rescale the pre-norm residual.
use_cache (bool, optional, defaults to True) – Whether the model should return its cache (the convolution and SSM states) so it can be reused during generation.
norm_before_gate (bool, optional, defaults to True) – Whether to apply normalization before the gate activation.
rms_norm (bool, optional, defaults to True) – Whether to use root mean square normalization.
chunk_size (int, optional, defaults to 256) – Size of chunks for processing long sequences.
tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the word embedding weights with the output projection weights.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.
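Here is a quick worked check of how the default sizes relate. Two parts are assumptions based on the reference Mamba2 design, not read from EasyDeL's code: the expand * hidden_size inner width and the chunk count. The time_step_rank rule comes from the parameter list above.

```python
import math

# Defaults taken from the Mamba2Config signature above.
hidden_size, expand = 4096, 2
num_heads, head_dim = 128, 64
chunk_size = 256

# Assumption: the inner (expanded) width is expand * hidden_size, as in reference Mamba2.
intermediate_size = expand * hidden_size  # 8192

# With the defaults, the heads tile the inner width exactly: 128 * 64 == 8192.
heads_cover_inner_width = num_heads * head_dim == intermediate_size

# time_step_rank="auto" resolves to ceil(hidden_size / 16), per the parameter docs.
time_step_rank = math.ceil(hidden_size / 16)  # 256

# Assumption: a sequence is split into ceil(seq_len / chunk_size) chunks for the scan.
seq_len = 1000
num_chunks = math.ceil(seq_len / chunk_size)  # 4
```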
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for distributing the Mamba2 model parameters across multiple devices.
These rules define how parameters should be partitioned when using techniques like Fully Sharded Data Parallelism (FSDP), Sharded Parallelism (SP), and Tensor Parallelism (TP). Each rule consists of a regex pattern matching parameter names and a corresponding PartitionSpec.
- Returns
- A tuple of tuples where each inner tuple contains:
A regex pattern matching parameter names
A PartitionSpec object specifying how to partition matching parameters
- Return type
tuple
- model_type: str = 'mamba2'#
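Rules of this (regex, PartitionSpec) form are usually applied by matching each parameter's path against the patterns in order. The sketch below assumes first-match-wins semantics. The parameter paths and rules are hypothetical. Plain tuples, with None meaning "replicated", stand in for jax.sharding.PartitionSpec so the example runs without JAX.

```python
import re

# Hypothetical (regex, spec) rules; tuples stand in for PartitionSpec axes,
# and None stands in for a fully replicated parameter.
RULES = (
    (r"embeddings/embedding", ("tp", ("fsdp", "sp"))),
    (r"mixer/in_proj/kernel", (("fsdp", "sp"), "tp")),
    (r"mixer/out_proj/kernel", ("tp", ("fsdp", "sp"))),
    (r".*", None),  # catch-all: everything else is replicated
)


def match_partition_rule(param_path, rules=RULES):
    """Return the spec of the first rule whose regex is found in param_path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path) is not None:
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")
```

Because the first match wins, specific patterns must come before the catch-all. Otherwise every parameter would fall through to the replicated spec.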
- class easydel.__init__.Mamba2ForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- prepare_inputs_for_generation(input_ids, inputs_embeds=None, cache_params: Optional[Mamba2Cache] = None, cache_position: Optional[Union[Array, ndarray, bool, number]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None, **kwargs)[source]#
Prepares the inputs for a generation step.
- Parameters
input_ids – Input token IDs.
inputs_embeds – Optional input embeddings, used in place of input_ids.
cache_params – tp.Optional[Mamba2Cache]: Cached recurrent state from previous generation steps.
cache_position – tp.Optional[chex.Array]: Positions of the current tokens in the sequence.
attention_mask – tp.Optional[chex.Array]: Mask marking which positions may be attended to.
- Returns
A dictionary of model inputs for the next generation step.
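A Mamba cache takes the place of past_key_values because each layer carries a fixed-size recurrent state. Once the prompt has been processed, only the newest token needs to be fed in. The toy scalar recurrence below is not Mamba2's selective scan, which also keeps convolution state. It only shows that resuming from a cached state gives the same output as reprocessing the whole sequence.

```python
import numpy as np


def scan_recurrence(x, a, b, h0=0.0):
    """Toy linear recurrence h_t = a * h_{t-1} + b * x_t.

    Returns the per-step outputs and the final state (the "cache").
    """
    h = h0
    outputs = []
    for x_t in x:
        h = a * h + b * x_t
        outputs.append(h)
    return np.array(outputs), h
```

Processing x[:-1] and keeping the returned state reproduces the last output of a full pass: you then feed only x[-1:] with h0=state.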
- class easydel.__init__.Mamba2Model(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.MambaConfig(vocab_size=50280, hidden_size=768, state_size=16, num_hidden_layers=32, layer_norm_epsilon=1e-05, pad_token_id=0, bos_token_id=0, eos_token_id=0, expand=2, conv_kernel=4, use_bias=False, use_conv_bias=True, hidden_act='silu', initializer_range=0.1, residual_in_fp32=True, time_step_rank='auto', time_step_scale=1.0, time_step_min=0.001, time_step_max=0.1, time_step_init_scheme='random', time_step_floor=0.0001, rescale_prenorm_residual=False, use_cache=True, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_mambapy: bool = False, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 50280) – Vocabulary size of the Mamba model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.
hidden_size (int, optional, defaults to 768) – Dimensionality of the embeddings and hidden states.
state_size (int, optional, defaults to 16) – State size of the Mamba model.
num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers (Mamba blocks) in the model.
layer_norm_epsilon (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.
pad_token_id (int, optional, defaults to 0) – The index of the padding token in the vocabulary.
bos_token_id (int, optional, defaults to 0) – The id of the beginning-of-sequence token.
eos_token_id (int, optional, defaults to 0) – The id of the end-of-sequence token.
expand (int, optional, defaults to 2) – Expansion factor for the intermediate size.
conv_kernel (int, optional, defaults to 4) – Kernel size of the convolution layer.
use_bias (bool, optional, defaults to False) – Whether to use bias in the linear layers.
use_conv_bias (bool, optional, defaults to True) – Whether to use bias in the convolution layer.
hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string). If string, “silu”, “gelu”, “relu”, “swish” and “gelu_new” are supported.
initializer_range (float, optional, defaults to 0.1) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
residual_in_fp32 (bool, optional, defaults to True) – Whether to compute the residual connection in float32.
time_step_rank (str or int, optional, defaults to “auto”) – The rank of the time step embedding. If set to “auto”, the rank is calculated as math.ceil(self.hidden_size / 16).
time_step_scale (float, optional, defaults to 1.0) – The scale factor for the time step embedding.
time_step_min (float, optional, defaults to 0.001) – The minimum value for the time step embedding.
time_step_max (float, optional, defaults to 0.1) – The maximum value for the time step embedding.
time_step_init_scheme (str, optional, defaults to “random”) – The initialization scheme for the time step embedding. Possible values are “random” and “uniform”.
time_step_floor (float, optional, defaults to 1e-4) – The floor value for the time step embedding.
rescale_prenorm_residual (bool, optional, defaults to False) – Whether to rescale the pre-norm residual.
use_cache (bool, optional, defaults to True) – Whether the model should return its cache (the convolution and SSM states) so it can be reused during generation.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.
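The reference Mamba implementation draws initial time steps log-uniformly between time_step_min and time_step_max, then clamps them from below at time_step_floor. Assuming EasyDeL follows the same scheme, the sampling looks like the sketch below; sample_time_steps is a hypothetical helper. The last line shows what time_step_rank="auto" resolves to for the default hidden_size.

```python
import math

import numpy as np


def sample_time_steps(n, time_step_min=0.001, time_step_max=0.1,
                      time_step_floor=1e-4, seed=0):
    """Draw n initial time steps log-uniformly in [time_step_min, time_step_max)."""
    rng = np.random.default_rng(seed)
    log_min, log_max = math.log(time_step_min), math.log(time_step_max)
    dt = np.exp(rng.random(n) * (log_max - log_min) + log_min)
    # Clamp from below; with the defaults the floor (1e-4) is below time_step_min.
    return np.maximum(dt, time_step_floor)


# time_step_rank="auto" with the default hidden_size of 768:
time_step_rank = math.ceil(768 / 16)  # 48
```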
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for distributing the Mamba model parameters across multiple devices.
These rules define how parameters should be partitioned when using techniques like Fully Sharded Data Parallelism (FSDP), Sharded Parallelism (SP), and Tensor Parallelism (TP). Each rule consists of a regex pattern matching parameter names and a corresponding PartitionSpec.
- Returns
- A tuple of tuples where each inner tuple contains:
A regex pattern matching parameter names
A PartitionSpec object specifying how to partition matching parameters
- Return type
tuple
- model_type: str = 'mamba'#
- class easydel.__init__.MambaForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- prepare_inputs_for_generation(input_ids, max_length, **kwargs)[source]#
Prepares the inputs for a generation step.
- Parameters
input_ids – Input token IDs.
max_length – Maximum length of the sequence to be generated.
**kwargs – Additional keyword arguments (for example attention_mask).
- Returns
A dictionary of model inputs for the next generation step.
- class easydel.__init__.MambaModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.MistralConfig(vocab_size: int = 32000, hidden_size: int = 4096, intermediate_size: int = 14336, head_dim: int = 128, num_hidden_layers: int = 32, num_attention_heads: int = 32, num_key_value_heads: int = 8, hidden_act='silu', max_position_embeddings=131072, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, pad_token_id=None, bos_token_id: int = 1, eos_token_id: int = 2, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling: Dict[str, Union[str, float]] = None, sliding_window=4096, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, number_rep_kv: int = 1, attention_dropout: float = 0.0, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, attention_bias: bool = False, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 32000) – Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.
hidden_size (int, optional, defaults to 4096) – Dimensionality of the embeddings and hidden states.
intermediate_size (int, optional, defaults to 14336) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in each decoder layer.
head_dim (int, optional, defaults to 128) – Dimensionality of each attention head.
num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer decoder.
num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (int, optional, defaults to 8) – Number of key and value heads for each attention layer (grouped-query attention).
hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string). If string, “silu”, “gelu”, “relu”, “swish” and “gelu_new” are supported.
max_position_embeddings (int, optional, defaults to 4096 * 32) – The maximum sequence length that this model might ever be used with.
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the RMS normalization layers.
use_cache (bool, optional, defaults to True) – Whether the model should return the last key/value states so they can be reused during generation.
pad_token_id (int, optional) – The index of the padding token in the vocabulary.
bos_token_id (int, optional, defaults to 1) – The id of the beginning-of-sequence token.
eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.
tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.
rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.
rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.
sliding_window (int, optional, defaults to 4096) – The sliding window size.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.
number_rep_kv (int, optional, defaults to 1) – Number of repetitions for the key and value vectors.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
use_scan_mlp (bool, optional, defaults to False) – Whether to use the scan implementation for the MLP.
scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size to use when scanning the MLP.
bits (int, optional) – The number of bits to quantize the model to.
attention_bias (bool, optional, defaults to False) – Whether to use bias in the attention layer.
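With the defaults above, 32 query heads share 8 key/value heads (grouped-query attention), so each KV head serves a group of 4 query heads. The sketch below shows the projection widths and the usual way KV heads are repeated to line up with the query heads. It illustrates the standard technique and is not EasyDeL's code; number_rep_kv is a separate EasyDeL option and is not modeled here.

```python
import numpy as np

# Defaults from the MistralConfig signature.
hidden_size, head_dim = 4096, 128
num_attention_heads, num_key_value_heads = 32, 8

q_proj_width = num_attention_heads * head_dim   # 4096
kv_proj_width = num_key_value_heads * head_dim  # 1024 for k, and again for v
group_size = num_attention_heads // num_key_value_heads  # 4 query heads per KV head


def repeat_kv(kv, n_rep):
    """(batch, seq, kv_heads, head_dim) -> (batch, seq, kv_heads * n_rep, head_dim).

    Each KV head is repeated n_rep times consecutively, so query head q uses
    KV head q // n_rep.
    """
    return np.repeat(kv, n_rep, axis=2)
```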
- attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, number_rep_kv: int = 1, bits: Optional[int] = None, attention_dropout: float = 0.0, rope_scaling: Dict[str, Union[str, float]] = None, attention_bias: bool = False, **kwargs)[source]#
Adds the following custom arguments to the configuration:
- Parameters
gradient_checkpointing – EasyDeLGradientCheckPointers: Gradient checkpointing strategy.
use_scan_mlp – bool: Whether to use the scan_mlp function.
scan_mlp_chunk_size – int: Chunk size used when scanning the MLP input.
number_rep_kv – int: Number of times the key and value vectors are repeated.
bits – tp.Optional[int]: Number of bits to use for quantization.
attention_dropout – float: Dropout rate for the attention layer.
attention_bias – bool: Whether to use bias in the attention layers.
rope_scaling – tp.Dict[str, tp.Union[str, float]]: RoPE scaling configuration.
- Return type
A tuple of the following
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- property granted_freq_max_position_embedding: int#
- property granted_mask_max_position_embedding: int#
- model_type: str = 'mistral'#
- class easydel.__init__.MistralForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Mistral model with a language modeling head for causal language modeling tasks.
This model is a transformer-based language model that uses sliding window attention for autoregressive language generation.
- config#
Configuration for the model.
- Type
MistralConfig
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
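Sliding window attention restricts each query to the most recent keys. Here is a minimal mask sketch. It assumes a window of `window` tokens that includes the current position; implementations differ by one on whether the boundary is inclusive.

```python
import numpy as np


def sliding_window_causal_mask(seq_len, window):
    """Boolean (seq_len, seq_len) mask: query i may attend to key j
    when j <= i (causal) and i - j < window (inside the window)."""
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    return (j <= i) & (i - j < window)
```

With seq_len=5 and window=3, the last query sees only positions 2, 3 and 4. Memory per query therefore stays bounded no matter how long the sequence grows.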
- class easydel.__init__.MistralForSequenceClassification(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Mistral model for sequence classification tasks.
This class extends the base Mistral model by adding a linear classification head to perform sequence classification tasks such as sentiment analysis or text classification.
- config#
Configuration for the model.
- Type
MistralConfig
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- class easydel.__init__.MistralModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Mistral model implementation.
This implements the Mistral language model architecture, utilizing transformer blocks with RMSNorm, sliding window attention, and rotary position embeddings.
- config#
Configuration for the model.
- Type
MistralConfig
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- class easydel.__init__.MixtralConfig(vocab_size=32000, hidden_size=4096, intermediate_size=14336, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=8, hidden_act='silu', max_position_embeddings=131072, initializer_range=0.02, rms_norm_eps=1e-05, use_cache=True, pad_token_id=None, bos_token_id=1, eos_token_id=2, tie_word_embeddings=False, rope_theta=1000000.0, sliding_window=4096, attention_dropout=0.0, num_experts_per_tok=2, num_local_experts=8, output_router_logits=False, router_aux_loss_coef=0.001, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, number_rep_kv: int = 1, bits: Optional[int] = None, rope_scaling: Dict[str, Union[str, float]] = None, attention_bias: bool = False, initialization_of_moe: bool = False, router_jitter_noise=0.0, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 32000) – Vocabulary size of the Mixtral model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.
hidden_size (int, optional, defaults to 4096) – Dimensionality of the embeddings and hidden states.
intermediate_size (int, optional, defaults to 14336) – Dimensionality of the “intermediate” (i.e., feed-forward) layer of each expert.
num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer decoder.
num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (int, optional, defaults to 8) – Number of key and value heads for each attention layer (grouped-query attention).
hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string). If string, “silu”, “gelu”, “relu”, “swish” and “gelu_new” are supported.
max_position_embeddings (int, optional, defaults to 4096 * 32) – The maximum sequence length that this model might ever be used with.
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the RMS normalization layers.
use_cache (bool, optional, defaults to True) – Whether the model should return the last key/value states so they can be reused during generation.
pad_token_id (int, optional) – The index of the padding token in the vocabulary.
bos_token_id (int, optional, defaults to 1) – The id of the beginning-of-sequence token.
eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.
tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.
rope_theta (float, optional, defaults to 1e6) – The theta value to use for rotary position embeddings.
sliding_window (int, optional, defaults to 4096) – The sliding window size.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
num_experts_per_tok (int, optional, defaults to 2) – Number of experts each token is routed to.
num_local_experts (int, optional, defaults to 8) – Total number of experts in each MoE layer.
output_router_logits (bool, optional, defaults to False) – Whether to output router logits.
router_aux_loss_coef (float, optional, defaults to 0.001) – Coefficient of the router auxiliary (load-balancing) loss.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.
use_scan_mlp (bool, optional, defaults to False) – Whether to use the scan implementation for the MLP.
scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size to use when scanning the MLP.
number_rep_kv (int, optional, defaults to 1) – Number of repetitions for the key and value vectors.
bits (int, optional) – The number of bits to quantize the model to.
rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.
attention_bias (bool, optional, defaults to False) – Whether to use bias in the attention layer.
initialization_of_moe (bool, optional, defaults to False) – Whether to initialize the MoE layers.
router_jitter_noise (float, optional, defaults to 0.0) – The jitter noise for the router.
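The num_experts_per_tok / num_local_experts pair describes top-k routing. Below is a NumPy sketch of the standard Mixtral-style router: softmax over experts, keep the top k, renormalize their weights. router_jitter_noise and capacity handling are omitted, and this is an illustration rather than EasyDeL's implementation.

```python
import numpy as np


def top_k_routing(router_logits, k=2):
    """Mixtral-style top-k routing over experts.

    router_logits: (num_tokens, num_experts).
    Returns (weights, expert_ids), each of shape (num_tokens, k).
    """
    # Softmax over the expert axis (shifted for numerical stability).
    shifted = router_logits - router_logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    # Indices of the k most probable experts, highest first.
    expert_ids = np.argsort(-probs, axis=-1)[:, :k]
    weights = np.take_along_axis(probs, expert_ids, axis=-1)
    # Renormalize so each token's selected-expert weights sum to 1.
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights, expert_ids
```

Each token's hidden state is then processed by its k selected experts, and their outputs are combined using these weights.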
- attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, number_rep_kv: int = 1, bits: Optional[int] = None, attention_dropout: float = 0.0, rope_scaling: Dict[str, Union[str, float]] = None, attention_bias: bool = False, initialization_of_moe: bool = False, **kwargs)[source]#
Adds the following custom arguments to the configuration:
- Parameters
gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing strategy. Defaults to EasyDeLGradientCheckPointers.NONE.
use_scan_mlp (bool, optional) – Whether to use scan for MLP layers. Defaults to False.
scan_mlp_chunk_size (int, optional) – Chunk size for scan MLP. Defaults to 1024.
number_rep_kv (int, optional) – Number of repetitions for key/value heads. Defaults to 1.
bits (tp.Optional[int], optional) – Quantization bits. Defaults to None.
attention_dropout (float, optional) – Dropout probability for attention. Defaults to 0.0.
rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – RoPE scaling configuration. Defaults to None.
attention_bias (bool, optional) – Whether to use bias in attention layers. Defaults to False.
initialization_of_moe (bool, optional) – Whether MoE layers are being initialized. Defaults to False.
**kwargs – Additional keyword arguments (ignored).
- Return type
A tuple of the following
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model. This method defines how the model’s parameters are partitioned across devices for distributed training and inference.
- Parameters
*args – Additional positional arguments (unused).
**kwargs – Additional keyword arguments (unused).
- Returns
- A tuple of partition rules, where each rule is a tuple
containing a regex pattern for parameter names and the corresponding PartitionSpec.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- static get_weight_decay_exclusions()[source]#
Returns a tuple of parameter names for which weight decay should be excluded.
- Returns
An empty tuple, indicating no specific weight decay exclusions for this model.
- Return type
tuple
- property granted_freq_max_position_embedding: int#
Returns the maximum position embedding size specifically for frequency-based position embeddings.
If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.
- Returns
The granted maximum position embedding size for frequency encoding.
- Return type
int
- property granted_mask_max_position_embedding: int#
Returns the maximum position embedding size specifically for mask-based position embeddings.
If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.
- Returns
The granted maximum position embedding size for mask encoding.
- Return type
int
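The fallback behavior of the two granted_* properties can be sketched with a hypothetical dataclass. The sketch assumes that "is set" means "is not None"; the real config class may test for this differently.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class HypotheticalPositionConfig:
    """Stand-in for the relevant config fields; not the real MixtralConfig."""

    max_position_embeddings: int = 131072
    freq_max_position_embeddings: Optional[int] = None
    mask_max_position_embeddings: Optional[int] = None

    @property
    def granted_freq_max_position_embedding(self) -> int:
        # Use the frequency-specific override if present, else the global limit.
        if self.freq_max_position_embeddings is not None:
            return self.freq_max_position_embeddings
        return self.max_position_embeddings

    @property
    def granted_mask_max_position_embedding(self) -> int:
        # Same fallback rule for the mask-specific limit.
        if self.mask_max_position_embeddings is not None:
            return self.mask_max_position_embeddings
        return self.max_position_embeddings
```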
- model_type: str = 'mixtral'#
- class easydel.__init__.MixtralForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Mixtral model with a Causal Language Modeling head.
This model consists of the base Mixtral transformer (MixtralModel) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction. It also handles the calculation of the auxiliary loss from the MoE layers.
- config#
Configuration object for the model.
- Type
MixtralConfig
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- model#
The core Mixtral transformer model.
- Type
MixtralModel
- lm_head#
The linear layer for projecting hidden states to vocabulary logits.
- Type
- num_experts#
Total number of experts.
- Type
int
- num_experts_per_tok#
Number of experts to route per token.
- Type
int
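In the Hugging Face Mixtral reference, the auxiliary loss is a Switch-Transformer-style load-balancing loss, scaled by router_aux_loss_coef and added to the language-modeling loss. Below is a NumPy sketch of that formulation, assuming EasyDeL computes something equivalent; both function names are hypothetical.

```python
import numpy as np


def load_balancing_loss(router_probs, expert_ids, num_experts):
    """Switch-style auxiliary loss.

    router_probs: (num_tokens, num_experts) softmax outputs.
    expert_ids:   (num_tokens, k) experts selected per token.
    """
    expert_mask = np.eye(num_experts)[expert_ids]        # (tokens, k, experts) one-hot
    tokens_per_expert = expert_mask.mean(axis=0)          # (k, experts) routed fraction
    router_prob_per_expert = router_probs.mean(axis=0)    # (experts,) mean probability
    return num_experts * np.sum(tokens_per_expert * router_prob_per_expert)


def total_loss(lm_loss, aux_loss, router_aux_loss_coef=0.001):
    # Assumed combination: the auxiliary loss is scaled and added to the LM loss.
    return lm_loss + router_aux_loss_coef * aux_loss
```

Take 8 experts, top-2 routing, uniform router probabilities, and an assignment where every expert is picked equally often. The auxiliary loss then evaluates to 2.0, so it adds 0.002 to the total with the default coefficient.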
- class easydel.__init__.MixtralForSequenceClassification(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Mixtral model with a Sequence Classification head.
This model consists of the base Mixtral transformer (MixtralModel) followed by a linear layer (score). The score layer projects the transformer’s output hidden states (typically the hidden state of the last non-padding token) to the number of classes for classification. It also handles the calculation of the auxiliary loss from the MoE layers.
- config#
Configuration object for the model.
- Type
MixtralConfig
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- model#
The core Mixtral transformer model.
- Type
MixtralModel
- score#
The linear layer for classification.
- Type
- num_experts#
Total number of experts.
- Type
int
- num_experts_per_tok#
Number of experts to route per token.
- Type
int
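Decoder-only classifiers commonly read the logits at the last non-padding token: under causal attention, it is the only position that has seen the whole sequence. Here is a sketch of that pooling step, assuming right padding and a known pad_token_id; pool_last_non_pad is a hypothetical name.

```python
import numpy as np


def pool_last_non_pad(logits, input_ids, pad_token_id):
    """Select logits at each sequence's last non-padding position.

    logits: (batch, seq, num_labels); input_ids: (batch, seq), right-padded.
    """
    is_pad = input_ids == pad_token_id
    seq_len = input_ids.shape[-1]
    # First padding index per row, or seq_len when the row has no padding.
    first_pad = np.where(is_pad.any(axis=-1), is_pad.argmax(axis=-1), seq_len)
    last_idx = first_pad - 1
    return logits[np.arange(logits.shape[0]), last_idx]
```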
- class easydel.__init__.MixtralModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
The base Mixtral model transformer.
This class represents the core transformer architecture of the Mixtral model, consisting of an embedding layer, multiple MixtralDecoderLayer layers (with sparse MoE), and a final layer normalization.
- config#
Configuration object for the model.
- Type
MixtralConfig
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- embed_tokens#
Embedding layer for input tokens.
- Type
nn.Embed
- layers#
List of decoder layers.
- Type
tp.List[MixtralDecoderLayer]
- gradient_checkpointing#
Gradient checkpointing configuration.
- class easydel.__init__.MptAttentionConfig(attn_type='multihead_attention', attn_pdrop=0, attn_impl='torch', clip_qkv=None, softmax_scale=None, prefix_lm=False, qk_ln=False, attn_uses_sequence_id=False, alibi=True, alibi_bias_max=8, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
This is the configuration class that stores the attention-related configuration of a [MptModel].
- Parameters
attn_type (str, optional, defaults to “multihead_attention”) – The type of attention to use. Can be either “multihead_attention” or “multiquery_attention”.
attn_pdrop (float, optional, defaults to 0.0) – The dropout probability applied to the attention output.
attn_impl (str, optional, defaults to “torch”) – The implementation of the attention mechanism. Can be either “torch” or “flash”.
clip_qkv (float, optional) – The clip value applied to the query, key, and value tensors.
softmax_scale (float, optional) – The scale factor applied to the attention scores before the softmax in the attention layer.
prefix_lm (bool, optional, defaults to False) – Whether to use a prefix LM.
qk_ln (bool, optional, defaults to False) – Whether to apply layer normalization to the query and key tensors.
attn_uses_sequence_id (bool, optional, defaults to False) – Whether the attention layer uses sequence IDs.
alibi (bool, optional, defaults to True) – Whether to use the ALiBi (Attention with Linear Biases) method.
alibi_bias_max (int, optional, defaults to 8) – The maximum value for the ALiBi bias.
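To make alibi and alibi_bias_max concrete, here is a plain-Python sketch of how per-head ALiBi slopes are commonly generated in MPT-style models: head i (1-indexed) gets slope 2^(-i * alibi_bias_max / n), where n is n_heads rounded up to a power of two, and the bias added to an attention score is the negative slope times the query-key distance. This is assumed to mirror the reference MPT recipe; the function names are hypothetical and EasyDeL's tensor layout may differ.

```python
import math
from typing import List


def alibi_slopes(n_heads: int, alibi_bias_max: int = 8) -> List[float]:
    # Round the head count up to the next power of two.
    padded = 2 ** math.ceil(math.log2(n_heads))
    slopes = [2.0 ** (-(i * alibi_bias_max / padded)) for i in range(1, padded + 1)]
    if padded != n_heads:
        # Take odd-indexed then even-indexed slopes and keep the first n_heads.
        slopes = (slopes[1::2] + slopes[::2])[:n_heads]
    return slopes


def alibi_bias(slope: float, query_pos: int, key_pos: int) -> float:
    # Penalize keys linearly with their distance from the query (causal case).
    return -slope * (query_pos - key_pos)


print(alibi_slopes(16)[:2])  # [0.7071067811865476, 0.5]
```

A larger alibi_bias_max makes the later heads' slopes smaller, so those heads attend more uniformly over distant tokens.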
- classmethod from_pretrained(pretrained_model_name_or_path, **kwargs) EasyDeLBaseConfig[source]#
Loads attention configuration from a pretrained model configuration file.
- Parameters
cls (type) – The class itself.
pretrained_model_name_or_path (str) – Path or identifier of the pretrained model.
**kwargs – Additional keyword arguments passed to get_config_dict and from_dict.
- Returns
An instance of MptAttentionConfig loaded from the pretrained model.
- Return type
- class easydel.__init__.MptConfig(d_model: int = 2048, n_heads: int = 16, n_layers: int = 24, expansion_ratio: int = 4, max_seq_len: int = 2048, vocab_size: int = 50368, resid_prob_drop: float = 0.0, layer_norm_epsilon: float = 1e-05, emb_prob_drop: float = 0.0, learned_pos_emb: bool = True, attn_config: Optional[MptAttentionConfig] = None, init_device: str = 'cpu', logit_scale: Optional[Union[float, str]] = None, no_bias: bool = True, verbose: int = 0, embedding_fraction: float = 1.0, norm_type: str = 'low_precision_layernorm', use_cache: bool = False, initializer_range=0.02, alibi: bool = True, use_bias: bool = False, act_fn: str = 'gelu', qk_ln: bool = False, use_lm_head: bool = False, use_norm_bias: bool = False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
d_model (int, optional, defaults to 2048) – Dimensionality of the embeddings and hidden states.
n_heads (int, optional, defaults to 16) – Number of attention heads for each attention layer in the Transformer decoder.
n_layers (int, optional, defaults to 24) – Number of hidden layers in the Transformer decoder.
expansion_ratio (int, optional, defaults to 4) – Expansion ratio of the feed-forward layer.
max_seq_len (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).
vocab_size (int, optional, defaults to 50368) – Vocabulary size of the MPT model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.
resid_prob_drop (float, optional, defaults to 0.0) – The dropout probability applied to the attention and feed-forward outputs before they are added back to the residual stream.
layer_norm_epsilon (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.
emb_prob_drop (float, optional, defaults to 0.0) – The dropout ratio for the embeddings.
learned_pos_emb (bool, optional, defaults to True) – Whether to learn positional embeddings.
attn_config ([MptAttentionConfig], optional) – The configuration of the attention layer.
init_device (str, optional, defaults to “cpu”) – The device to initialize the model on.
logit_scale (float or str, optional) – The logit scale. If set to “inv_sqrt_d_model”, the logit scale is calculated as 1 / math.sqrt(d_model).
no_bias (bool, optional, defaults to True) – If True, the linear layers are created without bias terms.
verbose (int, optional, defaults to 0) – The verbosity level.
embedding_fraction (float, optional, defaults to 1.0) – The fraction by which to scale the gradients of the embedding layer.
norm_type (str, optional, defaults to “low_precision_layernorm”) – The type of layer normalization to use.
use_cache (bool, optional, defaults to False) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
alibi (bool, optional, defaults to True) – Whether to use ALiBi (Attention with Linear Biases) method.
use_bias (bool, optional, defaults to False) – Whether to use bias in the linear layers.
act_fn (str, optional, defaults to “gelu”) – The activation function to use.
qk_ln (bool, optional, defaults to False) – Whether to apply layer normalization to the query and key tensors.
use_lm_head (bool, optional, defaults to False) – Whether to use a language modeling head.
use_norm_bias (bool, optional, defaults to False) – Whether to use bias in the layer normalization layers.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing strategy.
bits (int, optional) – The number of bits to quantize the model to.
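Several MptConfig quantities are derived from the stored fields rather than stored themselves. The sketch below shows that arithmetic, assuming the usual MPT layout (feed-forward width expansion_ratio * d_model, per-head width d_model // n_heads) and resolving logit_scale as documented above. The function name derived_mpt_sizes is hypothetical.

```python
import math
from typing import Optional, Union


def derived_mpt_sizes(
    d_model: int = 2048,
    n_heads: int = 16,
    expansion_ratio: int = 4,
    logit_scale: Optional[Union[float, str]] = None,
) -> dict:
    if d_model % n_heads != 0:
        raise ValueError("d_model must be divisible by n_heads")
    if logit_scale == "inv_sqrt_d_model":
        # Documented special value: scale the logits by 1 / sqrt(d_model).
        scale = 1 / math.sqrt(d_model)
    elif isinstance(logit_scale, str):
        raise ValueError(f"unknown logit_scale {logit_scale!r}")
    else:
        scale = logit_scale  # None leaves the logits unscaled.
    return {
        "ffn_dim": expansion_ratio * d_model,
        "head_dim": d_model // n_heads,
        "logit_scale": scale,
    }


print(derived_mpt_sizes())  # {'ffn_dim': 8192, 'head_dim': 128, 'logit_scale': None}
```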
- attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#
Attaches custom arguments to the configuration object.
This method allows adding or overriding configuration attributes dynamically. It primarily sets attributes related to gradient checkpointing and quantization bits.
- Parameters
gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing strategy. Defaults to EasyDeLGradientCheckPointers.NONE.
bits (tp.Optional[int], optional) – Quantization bits. Defaults to None.
**kwargs – Additional keyword arguments (ignored).
- attribute_map: dict[str, str] = {'hidden_size': 'd_model', 'max_position_embeddings': 'max_seq_len', 'num_attention_heads': 'n_heads', 'num_hidden_layers': 'n_layers', 'tie_word_embeddings': 'use_lm_head'}#
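attribute_map lets code written against the generic Hugging Face names (hidden_size, num_hidden_layers, and so on) read and write MPT's own field names. The class below is a minimal stand-in, not EasyDeL or transformers code, that reproduces the aliasing idea with the same mapping.

```python
class AliasedConfig:
    # Same mapping as MptConfig.attribute_map: generic name -> MPT field name.
    attribute_map = {
        "hidden_size": "d_model",
        "max_position_embeddings": "max_seq_len",
        "num_attention_heads": "n_heads",
        "num_hidden_layers": "n_layers",
        "tie_word_embeddings": "use_lm_head",
    }

    def __init__(self, **fields):
        for name, value in fields.items():
            setattr(self, name, value)

    def __getattr__(self, name):
        # Only called when normal lookup fails, i.e. for aliased names.
        mapping = type(self).attribute_map
        if name in mapping:
            return getattr(self, mapping[name])
        raise AttributeError(name)

    def __setattr__(self, name, value):
        # Writes through an alias land on the underlying MPT field.
        name = type(self).attribute_map.get(name, name)
        super().__setattr__(name, value)


cfg = AliasedConfig(d_model=2048, n_layers=24)
print(cfg.hidden_size, cfg.num_hidden_layers)  # 2048 24
```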
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model. This method defines how the model’s parameters are partitioned across devices for distributed training and inference.
- Parameters
*args – Additional positional arguments (unused).
**kwargs – Additional keyword arguments (unused).
- Returns
A tuple of partition rules, where each rule is a tuple containing a regex pattern for parameter names and the corresponding PartitionSpec.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
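Rules of this shape are typically consumed by matching each parameter's path against the patterns in order and taking the first match. The plain-Python sketch below shows that lookup. The rule patterns and mesh-axis names are hypothetical examples (not MPT's actual rules), and tuples of axis names stand in for jax.sharding.PartitionSpec so the snippet runs without JAX.

```python
import re
from typing import Optional, Sequence, Tuple

# Hypothetical rules: (regex over the parameter path, per-dimension mesh axes).
Rule = Tuple[str, Tuple[Optional[str], ...]]
RULES: Sequence[Rule] = (
    (r"wte/embedding", ("tp", "fsdp")),
    (r"attn/Wqkv/kernel", ("fsdp", "tp")),
    (r"norm.*/scale", (None,)),
    (r".*", (None,)),  # catch-all: replicate anything not matched above
)


def spec_for(param_path: str, rules: Sequence[Rule] = RULES) -> Tuple[Optional[str], ...]:
    for pattern, spec in rules:
        # re.search matches the pattern anywhere in the path; the first hit wins.
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


print(spec_for("transformer/blocks/0/attn/Wqkv/kernel"))  # ('fsdp', 'tp')
```

Because the first match wins, specific patterns must come before broad ones, and a final catch-all keeps unmatched parameters from raising.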
- property granted_freq_max_position_embedding: int#
Returns the maximum position embedding size specifically for frequency-based position embeddings.
If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_seq_len.
- Returns
The granted maximum position embedding size for frequency encoding.
- Return type
int
- property granted_mask_max_position_embedding: int#
Returns the maximum position embedding size specifically for mask-based position embeddings.
If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_seq_len.
- Returns
The granted maximum position embedding size for mask encoding.
- Return type
int
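Both granted_* properties apply the same fallback rule. A minimal sketch with a hypothetical standalone function; treating "is set" as "is not None" is an assumption about the real properties.

```python
from typing import Optional


def granted_max_position(override: Optional[int], max_seq_len: int) -> int:
    # Prefer the dedicated override; otherwise fall back to max_seq_len.
    return override if override is not None else max_seq_len


# A config with max_seq_len=2048 where only the frequency table is enlarged:
freq_limit = granted_max_position(8192, 2048)
mask_limit = granted_max_position(None, 2048)
print(freq_limit, mask_limit)  # 8192 2048
```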
- model_type: str = 'mpt'#
- class easydel.__init__.MptForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
MPT model with a language modeling head.
This model extends the base MptModel by adding a linear layer (lm_head) on top to predict the next token in a sequence, making it suitable for causal language modeling tasks.
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- lm_head#
The language modeling head. If use_lm_head in the config is True (tying embeddings), this will be None.
- Type
ParallelLinear, optional
- class easydel.__init__.MptModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
MPT model implementation.
This class implements the main MPT transformer model architecture, consisting of an embedding layer (token and optional positional), multiple MptBlock layers, and a final layer normalization.
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- wte#
Token embedding layer.
- Type
nn.Embed
- emb_drop#
Dropout layer applied after embeddings.
- Type
nn.Dropout
- norm_f#
Final layer normalization.
- Type
nn.LayerNorm
- alibi#
Precomputed ALiBi tensor if using ALiBi.
- Type
chex.Array, optional
- property alibi#
- class easydel.__init__.OPTConfig(vocab_size: int = 50272, hidden_size: int = 768, num_hidden_layers: int = 12, ffn_dim: int = 3072, max_position_embeddings: int = 2048, do_layer_norm_before: bool = True, _remove_final_layer_norm: bool = False, word_embed_proj_dim: int = None, dropout: float = 0.1, attention_dropout: float = 0.0, num_attention_heads: int = 12, activation_function: str = 'relu', layerdrop: float = 0.0, init_std: float = 0.02, use_cache: bool = True, pad_token_id: int = 1, bos_token_id: int = 2, eos_token_id: int = 2, enable_bias: bool = True, layer_norm_elementwise_affine: bool = True, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 50272) – Vocabulary size of the OPT model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.
hidden_size (int, optional, defaults to 768) – Dimensionality of the hidden states of the decoder layers.
num_hidden_layers (int, optional, defaults to 12) – Number of hidden layers in the Transformer decoder.
ffn_dim (int, optional, defaults to 3072) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer decoder.
max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
do_layer_norm_before (bool, optional, defaults to True) – Whether to perform layer normalization before the attention block.
_remove_final_layer_norm (bool, optional, defaults to False) – Whether to remove the final layer norm.
word_embed_proj_dim (int, optional) – The dimension of the word embedding projection. If not provided, it will default to hidden_size.
dropout (float, optional, defaults to 0.1) – The dropout probability for all fully connected layers in the embeddings and decoder.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
num_attention_heads (int, optional, defaults to 12) – Number of attention heads for each attention layer in the Transformer decoder.
activation_function (str or function, optional, defaults to “relu”) – The non-linear activation function (function or string) to use in the decoder. If a string, “gelu”, “relu”, “swish” and “gelu_new” are supported.
layerdrop (float, optional, defaults to 0.0) – The LayerDrop probability for the decoder. See the LayerDrop paper (https://arxiv.org/abs/1909.11556) for more details.
init_std (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
pad_token_id (int, optional, defaults to 1) – The index of the padding token in the vocabulary.
bos_token_id (int, optional, defaults to 2) – The id of the beginning-of-sequence token.
eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.
enable_bias (bool, optional, defaults to True) – Whether to use bias in the linear layers.
layer_norm_elementwise_affine (bool, optional, defaults to True) – Whether to use elementwise affine in the layer normalization layers.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing strategy.
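do_layer_norm_before chooses between a pre-norm and a post-norm residual block. The framework-free sketch below works on a single vector. Its layer_norm omits the learned scale and bias (the layer_norm_elementwise_affine=False case), and sublayer stands in for attention or the MLP. All names are hypothetical.

```python
import math
from typing import Callable, List

Vector = List[float]


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    # Normalize to zero mean and unit variance (no learned affine parameters).
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def residual_block(
    x: Vector, sublayer: Callable[[Vector], Vector], do_layer_norm_before: bool
) -> Vector:
    if do_layer_norm_before:
        # Pre-LN: normalize the sublayer input; the residual path stays unnormalized.
        return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]
    # Post-LN: add the residual first, then normalize the sum.
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])


double = lambda v: [2 * t for t in v]
print(residual_block([1.0, 2.0, 3.0, 4.0], double, do_layer_norm_before=True))
```

In the post-norm variant every block output is re-normalized, while in the pre-norm variant the residual stream accumulates unnormalized values, which is why pre-norm stacks usually end with a final layer norm (the one _remove_final_layer_norm controls).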
- attach_custom_arguments(vocab_size: int = 50272, hidden_size: int = 768, num_hidden_layers: int = 12, ffn_dim: int = 3072, max_position_embeddings: int = 2048, do_layer_norm_before: bool = True, _remove_final_layer_norm: bool = False, word_embed_proj_dim: int = None, dropout: float = 0.1, attention_dropout: float = 0.0, num_attention_heads: int = 12, activation_function: str = 'relu', layerdrop: float = 0.0, init_std: float = 0.02, use_cache: bool = True, pad_token_id: int = 1, bos_token_id: int = 2, eos_token_id: int = 2, enable_bias: bool = True, layer_norm_elementwise_affine: bool = True, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#
Attaches custom arguments to the configuration object.
This method allows dynamically adding or overriding configuration attributes. It iterates through the provided arguments and sets them as attributes of the configuration object if they don’t already exist.
- Parameters
vocab_size (int, optional) – Vocabulary size. Defaults to 50272.
hidden_size (int, optional) – Dimensionality of the encoder layers. Defaults to 768.
num_hidden_layers (int, optional) – Number of hidden layers. Defaults to 12.
ffn_dim (int, optional) – Dimensionality of the feed-forward layer. Defaults to 3072.
max_position_embeddings (int, optional) – Maximum sequence length. Defaults to 2048.
do_layer_norm_before (bool, optional) – Whether to apply layer norm before attention. Defaults to True.
_remove_final_layer_norm (bool, optional) – Whether to remove the final layer norm. Defaults to False.
word_embed_proj_dim (int, optional) – Dimension of the word embedding projection. Defaults to hidden_size.
dropout (float, optional) – Dropout probability. Defaults to 0.1.
attention_dropout (float, optional) – Attention dropout probability. Defaults to 0.0.
num_attention_heads (int, optional) – Number of attention heads. Defaults to 12.
activation_function (str, optional) – Activation function name. Defaults to “relu”.
layerdrop (float, optional) – LayerDrop probability. Defaults to 0.0.
init_std (float, optional) – Initialization standard deviation. Defaults to 0.02.
use_cache (bool, optional) – Whether to use key/value cache. Defaults to True.
pad_token_id (int, optional) – Padding token ID. Defaults to 1.
bos_token_id (int, optional) – Beginning-of-sequence token ID. Defaults to 2.
eos_token_id (int, optional) – End-of-sequence token ID. Defaults to 2.
enable_bias (bool, optional) – Whether to use bias in linear layers. Defaults to True.
layer_norm_elementwise_affine (bool, optional) – Whether layer norm uses elementwise affine parameters. Defaults to True.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing strategy. Defaults to EasyDeLGradientCheckPointers.NONE.
**kwargs – Additional keyword arguments to attach.
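The description above says the method sets each argument only if the attribute does not already exist. That behavior can be sketched as follows. The helper is a hypothetical stand-in for the real method, and using hasattr for the existence check is an assumption.

```python
from types import SimpleNamespace


def attach_missing(config, **kwargs):
    for name, value in kwargs.items():
        # Leave attributes that were already set untouched.
        if not hasattr(config, name):
            setattr(config, name, value)
    return config


cfg = attach_missing(SimpleNamespace(hidden_size=1024), hidden_size=768, ffn_dim=3072)
print(cfg.hidden_size, cfg.ffn_dim)  # 1024 3072
```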
- get_partition_rules(fully_sharded_data_parallel: bool = True)[source]#
Get the partition rules for the model.
- Parameters
fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- keys_to_ignore_at_inference = ['past_key_values']#
- model_type: str = 'opt'#
- class easydel.__init__.OPTForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
OPT Model with a Causal Language Modeling head.
This model consists of the base OPTModel followed by a linear layer (the language modeling head) to predict the next token logits.
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- lm_head#
The linear layer for projecting hidden states to vocabulary logits.
- Type
- prepare_inputs_for_generation(input_ids, max_length, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None)[source]#
Prepares the model inputs for a generation task.
- Parameters
input_ids – The input (prompt) token IDs.
max_length – The maximum total length of the sequence to be generated.
attention_mask (chex.Array, optional) – Mask marking which input positions should be attended to.
- Returns
A dictionary containing the past_key_values, the attention mask, and the position IDs.
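Flax-style decoders commonly build a static attention mask of length max_length and derive position IDs from the running sum of the mask, so that a left-padded prompt still starts at position 0. The sketch below shows that pattern for a single sequence in plain Python. It is modeled on the Hugging Face Flax convention and assumed, not confirmed, to match this method; the function name is hypothetical and key/value cache initialization is omitted.

```python
from typing import List, Optional, Tuple


def prepare_positions_and_mask(
    input_ids: List[int], max_length: int, attention_mask: Optional[List[int]] = None
) -> Tuple[List[int], List[int]]:
    seq_len = len(input_ids)
    if max_length < seq_len:
        raise ValueError("max_length must be at least the prompt length")
    if attention_mask is None:
        attention_mask = [1] * seq_len
    # Static mask of length max_length: prompt entries copied in, future slots set to 1
    # because the causal mask already hides positions that have not been generated.
    extended_mask = attention_mask + [1] * (max_length - seq_len)
    # Position ids are the running count of attended tokens minus one, so a left-padded
    # prompt still starts at position 0 (padding slots get -1 but are masked out).
    position_ids, running = [], 0
    for m in attention_mask:
        running += m
        position_ids.append(running - 1)
    return extended_mask, position_ids


mask, pos = prepare_positions_and_mask([0, 0, 5, 6, 7], 8, [0, 0, 1, 1, 1])
print(pos)  # [-1, -1, 0, 1, 2]
```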
- class easydel.__init__.OPTModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Base OPT Model class.
This class represents the core OPT model architecture, consisting primarily of the OPTDecoder.
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- decoder#
The OPT decoder stack.
- Type
OPTDecoder
- class easydel.__init__.ORPOConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: tp.Optional[str] = None, clip_grad: tp.Optional[float] = None, custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = None, dataloader_num_workers: tp.Optional[int] = 0, dataloader_pin_memory: tp.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: tp.Optional[int] = None, evaluation_steps: tp.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: tp.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: tp.Optional[dict] = None, learning_rate: float = 1e-06, learning_rate_end: tp.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: tp.Optional[LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: tp.Optional[int] = None, max_sequence_length: tp.Optional[int] = 4096, max_training_steps: tp.Optional[int] = None, model_name: str = 'ORPOTrainer', model_parameters: tp.Optional[dict] = None, metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: tp.Optional[int] = None, save_total_limit: tp.Optional[int] = None, scheduler: AVAILABLE_SCHEDULERS = EasyDeLSchedulers.NONE, sparsify_module: bool = False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: tp.Optional[dict] = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: tp.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: tp.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: tp.Optional[jnp.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: tp.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0, max_length: tp.Optional[int] = 1024, max_prompt_length: tp.Optional[int] = 512, max_completion_length: tp.Optional[int] = None, beta: float = 0.1, disable_dropout: bool = True, label_pad_token_id: int = -100, padding_value: tp.Optional[int] = None, generate_during_eval: bool = False, is_encoder_decoder: tp.Optional[bool] = None, dataset_num_proc: tp.Optional[int] = None)[source]#
Bases:
TrainingArguments
Configuration class for ORPO training settings.
This class inherits from TrainingArguments and holds configuration parameters specific to the ORPO model training. The dataclass automatically generates an initializer, and the __post_init__ method further processes some of the parameters after object initialization.
- model_name#
The name of the model. Default is “ORPOTrainer”.
- Type
str
- learning_rate#
The learning rate used during training. Default is 1e-6.
- Type
float
- max_length#
The maximum allowed sequence length for the input. Default is 1024.
- Type
Optional[int]
- max_prompt_length#
The maximum allowed length of the prompt portion of the input. Default is 512.
- Type
Optional[int]
- max_completion_length#
The maximum allowed length of the completion. If not provided, it is set to max_length - max_prompt_length.
- Type
Optional[int]
- beta#
Weight of the odds-ratio term in the ORPO loss (denoted λ in the ORPO paper). Default is 0.1.
- Type
float
- disable_dropout#
Flag to disable dropout during training. Default is True.
- Type
bool
- label_pad_token_id#
The token id used for padding labels. Default is -100.
- Type
int
- padding_value#
The value used for padding sequences. Default is None.
- Type
Optional[int]
- generate_during_eval#
Flag indicating whether to generate sequences during evaluation. Default is False.
- Type
bool
- is_encoder_decoder#
Flag to indicate if the model is encoder-decoder. Default is None.
- Type
Optional[bool]
- model_init_kwargs#
Additional keyword arguments for model initialization. Default is None.
- Type
Optional[Dict[str, Any]]
- dataset_num_proc#
Number of processes to use for dataset processing. Default is None.
- Type
Optional[int]
- max_sequence_length#
Computed attribute representing the maximum sequence length used for training. It is set in the __post_init__ method.
- Type
int
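The documented default for max_completion_length is plain arithmetic: with the defaults it becomes 1024 - 512 = 512. The function below is a hypothetical stand-in for that part of __post_init__; returning None when max_length or max_prompt_length is missing is an assumption.

```python
from typing import Optional


def resolve_max_completion_length(
    max_length: Optional[int] = 1024,
    max_prompt_length: Optional[int] = 512,
    max_completion_length: Optional[int] = None,
) -> Optional[int]:
    # An explicit value wins; otherwise the completion gets whatever the prompt leaves.
    if max_completion_length is not None:
        return max_completion_length
    if max_length is None or max_prompt_length is None:
        return None  # assumption: nothing to derive from
    return max_length - max_prompt_length


print(resolve_max_completion_length())  # 512
```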
- beta: float = 0.1#
- dataset_num_proc: Optional[int] = None#
- disable_dropout: bool = True#
- classmethod from_dict(data)#
Create an instance from a dictionary (deserialization).
- classmethod from_json(json_str)#
Create an instance from a JSON string.
- generate_during_eval: bool = False#
- is_encoder_decoder: Optional[bool] = None#
- label_pad_token_id: int = -100#
- learning_rate: float = 1e-06#
- max_completion_length: Optional[int] = None#
- max_length: Optional[int] = 1024#
- max_prompt_length: Optional[int] = 512#
- model_name: str = 'ORPOTrainer'#
- padding_value: Optional[int] = None#
- replace(**kwargs)#
- to_dict()#
Convert the instance to a dictionary for JSON serialization.
- to_json(**kwargs)#
Convert the instance to a JSON string.
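For orientation, here is a plain-Python sketch of the ORPO objective as it is commonly implemented (for example in Hugging Face TRL). Each response's log-probability is averaged over its non-padding label tokens (positions equal to label_pad_token_id are skipped). The log odds ratio between the chosen and rejected responses is passed through a log-sigmoid, and beta weights that term against the negative log-likelihood of the chosen response. This is an illustrative assumption, not EasyDeL's trainer code, and all function names are hypothetical.

```python
import math
from typing import List


def mean_logp(token_logps: List[float], labels: List[int], label_pad_token_id: int = -100) -> float:
    # Average log-probability over positions whose label is not padding.
    kept = [lp for lp, lab in zip(token_logps, labels) if lab != label_pad_token_id]
    return sum(kept) / len(kept)


def log_sigmoid(x: float) -> float:
    # Numerically stable log(1 / (1 + exp(-x))).
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))


def orpo_loss(chosen_logp: float, rejected_logp: float, beta: float = 0.1) -> float:
    # log odds(p) = log p - log(1 - p), with p given as a log-probability (< 0).
    log_odds = (chosen_logp - rejected_logp) - (
        math.log1p(-math.exp(chosen_logp)) - math.log1p(-math.exp(rejected_logp))
    )
    nll = -chosen_logp  # SFT term on the chosen response
    return nll - beta * log_sigmoid(log_odds)


print(orpo_loss(math.log(0.6), math.log(0.2)) < orpo_loss(math.log(0.2), math.log(0.6)))  # True
```

Setting beta to 0 reduces the objective to ordinary supervised fine-tuning on the chosen responses; larger values push harder on separating chosen from rejected responses.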
- class easydel.__init__.ORPOTrainer(arguments: ORPOConfig, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, data_collator: Optional[DPODataCollatorWithPadding] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, processing_class: Optional[Any] = None)[source]#
Bases:
Trainer
- arguments: ORPOConfig#
- build_tokenized_answer(prompt: str, answer: str) Dict[str, ndarray][source]#
Tokenizes a prompt and answer pair, handling special tokens and padding/truncation.
- Parameters
prompt (str) – The prompt text.
answer (str) – The answer text.
- Returns
A dictionary containing the tokenized prompt and answer, along with attention masks.
- Return type
tp.Dict[str, np.ndarray]
- Raises
ValueError – If there’s a mismatch in token lengths.
- configure_functions() TrainerConfigureFunctionOutput[source]#
Configures and JIT-compiles the training and evaluation step functions.
This method sets up the necessary functions for training and evaluation, including:
- Initialization of the model state.
- Sharding of the model parameters and optimizer state.
- JIT-compilation of the training and evaluation step functions.
- Returns
An object containing the configured functions and other relevant information.
- Return type
- create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') Callable[source]#
Creates a data collection function for batching.
For ORPO training, this method simply returns the pre-configured data_collator.
- Parameters
max_sequence_length (int) – The maximum sequence length (not used in this implementation).
truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode (not used in this implementation). Defaults to “keep_end”.
- Returns
The data collator function.
- Return type
tp.Callable
- tokenize_row(feature: Dict[str, str], state: Optional[object] = None) Dict[str, ndarray][source]#
Tokenizes a single row of data from the ORPO dataset.
This method tokenizes the prompt, chosen response, and rejected response, handles padding and truncation, and prepares the data for ORPO training.
- Parameters
feature (tp.Dict) – A dictionary containing the “prompt”, “chosen”, and “rejected” texts.
state (EasyDeLState, optional) – Not used in this implementation. Defaults to None.
- Returns
A dictionary containing the tokenized prompt, chosen response, and rejected response, along with attention masks and labels.
- Return type
tp.Dict
- Raises
ValueError – If the input data types are incorrect.
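The truncation step in DPO/ORPO-style tokenization usually works on token lists as follows. If the prompt plus the longer response exceeds max_length, the prompt is cut to max_prompt_length (keeping its end or its start, per truncation_mode). If the pair is still too long, both responses are cut to max_length - max_prompt_length. Prompt positions are then masked in the labels with label_pad_token_id. The sketch follows that TRL-style scheme on plain integer lists; the function name and exact rules are assumptions, not this trainer's confirmed behavior.

```python
from typing import Dict, List


def truncate_and_label(
    prompt: List[int],
    chosen: List[int],
    rejected: List[int],
    max_length: int = 1024,
    max_prompt_length: int = 512,
    truncation_mode: str = "keep_end",
    label_pad_token_id: int = -100,
) -> Dict[str, List[int]]:
    longer = max(len(chosen), len(rejected))
    if len(prompt) + longer > max_length:
        # Shorten the prompt first, keeping its end (most recent context) or its start.
        if truncation_mode == "keep_end":
            prompt = prompt[-max_prompt_length:]
        elif truncation_mode == "keep_start":
            prompt = prompt[:max_prompt_length]
        else:
            raise ValueError(f"unknown truncation_mode {truncation_mode!r}")
    if len(prompt) + longer > max_length:
        # Still too long: cut both responses to the remaining budget.
        budget = max_length - max_prompt_length
        chosen, rejected = chosen[:budget], rejected[:budget]
    out: Dict[str, List[int]] = {}
    for name, response in (("chosen", chosen), ("rejected", rejected)):
        out[f"{name}_input_ids"] = prompt + response
        # The loss sees only response tokens; prompt tokens are masked out.
        out[f"{name}_labels"] = [label_pad_token_id] * len(prompt) + response
    return out


row = truncate_and_label(list(range(10)), [100, 101, 102], [200], max_length=8, max_prompt_length=4)
print(row["chosen_input_ids"])  # [6, 7, 8, 9, 100, 101, 102]
```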
- class easydel.__init__.Olmo2Config(vocab_size=50304, hidden_size=4096, intermediate_size=11008, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=None, hidden_act='silu', max_position_embeddings=2048, initializer_range=0.02, use_cache=True, pad_token_id=1, bos_token_id=None, eos_token_id=50279, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, attention_bias=False, attention_dropout=0.0, rms_norm_eps=1e-05, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
This is the configuration class to store the configuration of an [Olmo2Model]. It is used to instantiate an OLMo2 model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of [allenai/Olmo2-7B-1124-hf](https://huggingface.co/allenai/Olmo2-7B-1124-hf).
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 50304) – Vocabulary size of the Olmo2 model. Defines the number of different tokens that can be represented by the input_ids passed when calling [Olmo2Model].
hidden_size (int, optional, defaults to 4096) – Dimension of the hidden representations.
intermediate_size (int, optional, defaults to 11008) – Dimension of the MLP representations.
num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer decoder.
num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (int, optional) – The number of key/value heads used to implement Grouped Query Attention (GQA). If num_key_value_heads=num_attention_heads, the model uses Multi-Head Attention (MHA); if num_key_value_heads=1, it uses Multi-Query Attention (MQA); otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each grouped key and value head should be constructed by mean-pooling all the original heads within that group. For more details, check out [this paper](https://arxiv.org/pdf/2305.13245.pdf). If not specified, defaults to num_attention_heads.
hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) in the decoder.
max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with.
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
pad_token_id (int, optional, defaults to 1) – Padding token id.
bos_token_id (int, optional) – Beginning of stream token id.
eos_token_id (int, optional, defaults to 50279) – End of stream token id.
tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the input and output word embeddings.
rope_theta (float, optional, defaults to 10000.0) – The base period of the RoPE embeddings.
rope_scaling (tp.Dict, optional) – Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling strategies: linear and dynamic. The scaling factor must be a float greater than 1. The expected format is {"type": strategy_name, "factor": scaling_factor}. When using this flag, don’t update max_position_embeddings to the expected new maximum. See the following thread for more information on how these scaling strategies behave: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an experimental feature, subject to breaking API changes in future versions.
attention_bias (bool, optional, defaults to False) – Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
rms_norm_eps (float, optional, defaults to 1e-05) – The epsilon used by the rms normalization layers.
>>> from transformers import Olmo2Model, Olmo2Config
>>> # Initializing an Olmo2 7B style configuration
>>> configuration = Olmo2Config()
>>> # Initializing a model from the Olmo2 7B style configuration
>>> model = Olmo2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
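The num_key_value_heads description mentions building each grouped key/value head by mean-pooling the original heads in its group. A plain-Python sketch of that conversion on per-head weight vectors (one flattened list per head); the function name is hypothetical, and a real checkpoint conversion would apply the same averaging to reshaped projection matrices.

```python
from typing import List


def mean_pool_kv_heads(heads: List[List[float]], num_key_value_heads: int) -> List[List[float]]:
    num_heads = len(heads)
    if num_heads % num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    group = num_heads // num_key_value_heads
    pooled = []
    for g in range(num_key_value_heads):
        members = heads[g * group:(g + 1) * group]
        # Element-wise mean over the original heads that will share this KV head.
        pooled.append([sum(vals) / group for vals in zip(*members)])
    return pooled


heads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
print(mean_pool_kv_heads(heads, 2))  # [[2.0, 3.0], [6.0, 7.0]]
```

After conversion, query head h uses key/value head h // group, so num_key_value_heads=num_attention_heads leaves the weights unchanged (MHA) and num_key_value_heads=1 averages everything into a single shared head (MQA).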
- attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None)[source]#
Attaches custom arguments to the configuration object.
This method allows adding or overriding configuration attributes dynamically. It primarily sets attributes related to gradient checkpointing, MLP scanning, and quantization bits.
- Parameters
gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing strategy. Defaults to EasyDeLGradientCheckPointers.NONE.
use_scan_mlp (bool, optional) – Whether to use scan for MLP layers. Defaults to False.
scan_mlp_chunk_size (int, optional) – Chunk size for scan MLP. Defaults to 1024.
bits (tp.Optional[int], optional) – Quantization bits. Defaults to None.
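use_scan_mlp and scan_mlp_chunk_size trade speed for memory by running the MLP over the sequence in fixed-size chunks instead of all at once. Because the MLP acts on each position independently, the chunked result matches the unchunked one. The plain-Python sketch below demonstrates that equivalence. Chunking along the sequence axis (presumably via jax.lax.scan in EasyDeL) is an assumption here, and the toy MLP is hypothetical.

```python
from typing import Callable, List

Vector = List[float]


def chunked_apply(
    hidden_states: List[Vector], mlp: Callable[[Vector], Vector], scan_mlp_chunk_size: int = 1024
) -> List[Vector]:
    out: List[Vector] = []
    for start in range(0, len(hidden_states), scan_mlp_chunk_size):
        chunk = hidden_states[start:start + scan_mlp_chunk_size]
        # Process one chunk at a time; a vectorized version would materialize
        # only this chunk's intermediate activations.
        out.extend(mlp(h) for h in chunk)
    return out


def toy_mlp(h: Vector) -> Vector:
    # Position-wise: expand to twice the width, apply ReLU, then sum back down.
    hidden = [max(0.0, 2 * x - 1) for x in h] + [max(0.0, -x) for x in h]
    return [a + b for a, b in zip(hidden[: len(h)], hidden[len(h):])]


seq = [[float(i), float(-i)] for i in range(10)]
print(chunked_apply(seq, toy_mlp, 3) == [toy_mlp(h) for h in seq])  # True
```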
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model. This method defines how the model’s parameters are partitioned across devices for distributed training and inference.
- Parameters
*args – Additional positional arguments (unused).
**kwargs – Additional keyword arguments (unused).
- Returns
A tuple of partition rules, where each rule is a tuple containing a regex pattern for parameter names and the corresponding PartitionSpec.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
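Each rule pairs a regular expression with a sharding spec. The sketch below assumes first-match-wins lookup with `re.search` over slash-separated parameter paths. The rule set, the path format, and the `match_partition_spec` helper are hypothetical, and plain tuples stand in for `jax.sharding.PartitionSpec` so the snippet runs without JAX.

```python
import re


def match_partition_spec(rules, param_name):
    """Return the spec of the first rule whose regex is found in param_name."""
    for pattern, spec in rules:
        if re.search(pattern, param_name):
            return spec
    raise ValueError(f"no partition rule matches {param_name!r}")


# Hypothetical rules; in EasyDeL the second element would be a PartitionSpec.
RULES = (
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"self_attn/(q_proj|k_proj|v_proj)/kernel", ("fsdp", "tp")),
    (r"self_attn/o_proj/kernel", ("tp", "fsdp")),
    (r".*", ()),  # catch-all: an empty spec means "replicate"
)

print(match_partition_spec(RULES, "model/layers/0/self_attn/q_proj/kernel"))  # ('fsdp', 'tp')
print(match_partition_spec(RULES, "model/norm/scale"))  # ()
```

Putting the catch-all last matters: with first-match semantics, a leading `.*` would shadow every other rule.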
- property granted_freq_max_position_embedding: int#
Returns the maximum position embedding size specifically for frequency-based position embeddings.
If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.
- Returns
The granted maximum position embedding size for frequency encoding.
- Return type
int
- property granted_mask_max_position_embedding: int#
Returns the maximum position embedding size specifically for mask-based position embeddings.
If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.
- Returns
The granted maximum position embedding size for mask encoding.
- Return type
int
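Both `granted_*` properties follow the same fallback rule. A minimal sketch of that rule, using a `SimpleNamespace` as a stand-in config and a hypothetical helper name (the real properties live on the config class):

```python
from types import SimpleNamespace


def granted_max_position(config, override_attr, fallback_attr="max_position_embeddings"):
    """Return config.<override_attr> if it is set (not None), else config.<fallback_attr>."""
    value = getattr(config, override_attr, None)
    return value if value is not None else getattr(config, fallback_attr)


cfg = SimpleNamespace(
    max_position_embeddings=4096,
    freq_max_position_embeddings=None,  # unset: falls back to max_position_embeddings
    mask_max_position_embeddings=8192,  # set: used as-is
)
print(granted_max_position(cfg, "freq_max_position_embeddings"))  # 4096
print(granted_max_position(cfg, "mask_max_position_embeddings"))  # 8192
```

For OpenELMConfig the fallback attribute would be `max_context_length` instead of `max_position_embeddings`.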
- keys_to_ignore_at_inference = ['past_key_values']#
- model_type: str = 'olmo2'#
- class easydel.__init__.Olmo2ForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
OLMo-2 model with a Causal Language Modeling head.
This model consists of the base OLMo-2 transformer (Olmo2Model) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- model#
The core OLMo-2 transformer model.
- Type
- lm_head#
The linear layer for projecting hidden states to vocabulary logits.
- Type
- class easydel.__init__.Olmo2ForSequenceClassification(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
OLMo-2 model with a Sequence Classification head.
This model consists of the base OLMo-2 transformer (Olmo2Model) followed by a linear layer (score) that projects the transformer’s output hidden states (typically the hidden state of the last token) to the number of classes for classification.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- model#
The core OLMo-2 transformer model.
- Type
- score#
The linear layer for classification.
- Type
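The classification head scores "typically the hidden state of the last token". In a right-padded batch, the last real token sits at a different index in each row. A minimal sketch of one common way to find it (scan backwards past `pad_token_id`), assuming right padding and nested lists in place of arrays; `last_token_logits` is hypothetical and not EasyDeL's code.

```python
def last_token_logits(input_ids, logits, pad_token_id):
    """For each sequence, pick the per-position logits at the last non-padding token."""
    pooled = []
    for ids, seq_logits in zip(input_ids, logits):
        last = len(ids) - 1
        # Walk left over trailing padding (assumes right padding).
        while last > 0 and ids[last] == pad_token_id:
            last -= 1
        pooled.append(seq_logits[last])
    return pooled


input_ids = [[5, 6, 7], [5, 6, 0]]  # second row ends with one pad token (id 0)
logits = [  # [batch][seq_len][num_labels]
    [[0.1, 0.9], [0.2, 0.8], [0.3, 0.7]],
    [[1.0, 0.0], [0.6, 0.4], [9.0, 9.0]],  # the [9.0, 9.0] row sits on padding
]
print(last_token_logits(input_ids, logits, pad_token_id=0))  # [[0.3, 0.7], [0.6, 0.4]]
```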
- class easydel.__init__.Olmo2Model(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
The base OLMo-2 model transformer.
This class represents the core transformer architecture of the OLMo-2 model, consisting of an embedding layer, multiple Olmo2DecoderLayer layers, and a final RMS normalization layer.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- embed_tokens#
Embedding layer for input tokens.
- Type
nn.Embed
- layers#
List of decoder layers.
- Type
tp.List[Olmo2DecoderLayer]
- gradient_checkpointing#
Gradient checkpointing configuration.
- class easydel.__init__.OlmoConfig(vocab_size=50304, hidden_size=4096, intermediate_size=11008, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=None, hidden_act='silu', max_position_embeddings=2048, initializer_range=0.02, use_cache=True, pad_token_id=1, bos_token_id=None, eos_token_id=50279, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, attention_bias=False, attention_dropout=0.0, clip_qkv=None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 50304) – Vocabulary size of the Olmo model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.
hidden_size (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.
intermediate_size (int, optional, defaults to 11008) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer encoder. Will default to num_attention_heads if not set.
hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.
max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/value attention states (not used by all models). Only relevant if config.is_decoder=True.
pad_token_id (int, optional, defaults to 1) – The index of the padding token in the vocabulary.
bos_token_id (int, optional) – The id of the beginning-of-sequence token.
eos_token_id (int, optional, defaults to 50279) – The id of the end-of-sequence token.
tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.
rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.
rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.
attention_bias (bool, optional, defaults to False) – Whether to use attention bias.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
clip_qkv (float, optional) – The clip value applied to the query, key, and value tensors.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing strategy.
use_scan_mlp (bool, optional, defaults to False) – Whether to use the scan implementation for the MLP.
scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size to use when scanning the MLP.
bits (int, optional) – The number of bits to quantize the model to.
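`clip_qkv` bounds the projected query, key, and value activations. A minimal sketch, assuming element-wise clamping to `[-clip_qkv, clip_qkv]` as in the Hugging Face OLMo implementation; nested lists stand in for arrays and `clip_projections` is a hypothetical helper.

```python
def clip_projections(q, k, v, clip_qkv=None):
    """Clamp every element of q, k and v to [-clip_qkv, clip_qkv]; no-op when clip_qkv is None."""
    if clip_qkv is None:
        return q, k, v

    def clamp(row):
        return [max(-clip_qkv, min(clip_qkv, x)) for x in row]

    return tuple([clamp(row) for row in t] for t in (q, k, v))


q, k, v = clip_projections([[10.0, -0.5]], [[-9.0, 1.0]], [[0.0, 8.0]], clip_qkv=8.0)
print(q, k, v)  # [[8.0, -0.5]] [[-8.0, 1.0]] [[0.0, 8.0]]
```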
- attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None)[source]#
Attaches custom arguments to the configuration object.
This method allows adding or overriding configuration attributes dynamically. It primarily sets attributes related to gradient checkpointing, MLP scanning, and quantization bits.
- Parameters
gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing strategy. Defaults to EasyDeLGradientCheckPointers.NONE.
use_scan_mlp (bool, optional) – Whether to use scan for MLP layers. Defaults to False.
scan_mlp_chunk_size (int, optional) – Chunk size for scan MLP. Defaults to 1024.
bits (tp.Optional[int], optional) – Quantization bits. Defaults to None.
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model. This method defines how the model’s parameters are partitioned across devices for distributed training and inference.
- Parameters
*args – Additional positional arguments (unused).
**kwargs – Additional keyword arguments (unused).
- Returns
A tuple of partition rules, where each rule is a tuple containing a regex pattern for parameter names and the corresponding PartitionSpec.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- property granted_freq_max_position_embedding: int#
Returns the maximum position embedding size specifically for frequency-based position embeddings.
If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.
- Returns
The granted maximum position embedding size for frequency encoding.
- Return type
int
- property granted_mask_max_position_embedding: int#
Returns the maximum position embedding size specifically for mask-based position embeddings.
If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.
- Returns
The granted maximum position embedding size for mask encoding.
- Return type
int
- model_type: str = 'olmo'#
- class easydel.__init__.OlmoForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
OLMo model with a Causal Language Modeling head.
This model consists of the base OLMo transformer (OlmoModel) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- lm_head#
The linear layer for projecting hidden states to vocabulary logits.
- Type
- class easydel.__init__.OlmoModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
The base OLMo model transformer.
This class represents the core transformer architecture of the OLMo model, consisting of an embedding layer and multiple OlmoDecoderLayer layers. Note that OLMo does not have a final layer normalization.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- embed_tokens#
Embedding layer for input tokens.
- Type
nn.Embed
- layers#
List of decoder layers.
- Type
tp.List[OlmoDecoderLayer]
- gradient_checkpointing#
Gradient checkpointing configuration.
- class easydel.__init__.OpenELMConfig(vocab_size: int = 32000, max_context_length: int = 2048, num_transformer_layers: int = 12, model_dim: int = 2048, head_dim: int = 128, qkv_multipliers: Union[Number, List[Number]] = 1.0, num_query_heads: Optional[int] = None, num_gqa_groups: int = 1, ffn_multipliers: Union[Number, List[Number]] = 4.0, ffn_with_glu: bool = True, ffn_dim_divisor: int = 256, activation_fn_name: str = 'swish', normalization_layer_name: str = 'rms_norm', normalize_qk_projections: bool = False, share_input_output_layers: bool = False, rope_freq_constant: int = 10000, rope_max_length: int = 4096, initializer_range: float = 0.02, use_cache: bool = True, bos_token_id: int = 1, eos_token_id: int = 2, rope_scaling: Dict[str, Union[str, float]] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 32000) – Vocabulary size of the OpenELM model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.
max_context_length (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).
num_transformer_layers (int, optional, defaults to 12) – Number of hidden layers in the Transformer encoder.
model_dim (int, optional, defaults to 2048) – Dimensionality of the encoder layers and the pooler layer.
head_dim (int, optional, defaults to 128) – Dimensionality of the attention heads.
qkv_multipliers (float or list of float, optional, defaults to 1.0) – The multiplier for the query, key, and value projections.
num_query_heads (int, optional) – Number of query heads. If not provided, it will be calculated based on model_dim and head_dim.
num_gqa_groups (int, optional, defaults to 1) – Number of GQA (Grouped Query Attention) groups.
ffn_multipliers (float or list of float, optional, defaults to 4.0) – The multiplier for the feed-forward network.
ffn_with_glu (bool, optional, defaults to True) – Whether to use a gated linear unit (GLU) in the feed-forward network.
ffn_dim_divisor (int, optional, defaults to 256) – The divisor for the feed-forward network dimension.
activation_fn_name (str, optional, defaults to “swish”) – The activation function to use.
normalization_layer_name (str, optional, defaults to “rms_norm”) – The normalization layer to use.
normalize_qk_projections (bool, optional, defaults to False) – Whether to normalize the query and key projections.
share_input_output_layers (bool, optional, defaults to False) – Whether to share the input and output layers.
rope_freq_constant (int, optional, defaults to 10000) – The frequency constant for Rotary Position Embeddings (RoPE).
rope_max_length (int, optional, defaults to 4096) – The maximum length for RoPE.
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/value attention states (not used by all models). Only relevant if config.is_decoder=True.
bos_token_id (int, optional, defaults to 1) – The id of the beginning-of-sequence token.
eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.
rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing strategy.
use_scan_mlp (bool, optional, defaults to False) – Whether to use the scan implementation for the MLP.
scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size to use when scanning the MLP.
bits (int, optional) – The number of bits to quantize the model to.
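OpenELM allocates parameters non-uniformly across layers. The sketch below assumes the behavior of Apple's reference OpenELM code: a `(min, max)` pair of `ffn_multipliers` is interpolated linearly across `num_transformer_layers`, and each layer's FFN width is rounded to a multiple of `ffn_dim_divisor`. The helper names are hypothetical, and EasyDeL's code may differ in detail.

```python
def make_divisible(value, divisor, min_value=None):
    """Round value to the nearest multiple of divisor, never ending up more than 10% below value."""
    if min_value is None:
        min_value = divisor
    new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
    if new_value < 0.9 * value:
        new_value += divisor
    return new_value


def per_layer_multipliers(multipliers, num_layers):
    """Expand a scalar, or a (min, max) pair interpolated linearly, into one multiplier per layer."""
    if isinstance(multipliers, (int, float)):
        return [float(multipliers)] * num_layers
    low, high = multipliers
    if num_layers == 1:
        return [float(low)]
    step = (high - low) / (num_layers - 1)
    return [round(low + i * step, 2) for i in range(num_layers)]


def ffn_dims(model_dim, ffn_multipliers, num_layers, ffn_dim_divisor=256):
    """Per-layer feed-forward width, rounded to a multiple of ffn_dim_divisor."""
    return [
        make_divisible(m * model_dim, ffn_dim_divisor)
        for m in per_layer_multipliers(ffn_multipliers, num_layers)
    ]


print(ffn_dims(2048, 4.0, 3))         # [8192, 8192, 8192]
print(ffn_dims(2048, (0.5, 4.0), 3))  # [1024, 4608, 8192]
```

The same expand-then-round pattern would apply to `qkv_multipliers`, with the rounding unit tied to `head_dim` so that each layer gets a whole number of heads.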
- attribute_map: dict[str, str] = {'tie_word_embedding': 'share_input_output_layers'}#
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model. This method defines how the model’s parameters are partitioned across devices for distributed training and inference.
- Parameters
*args – Additional positional arguments (unused).
**kwargs – Additional keyword arguments (unused).
- Returns
A tuple of partition rules, where each rule is a tuple containing a regex pattern for parameter names and the corresponding PartitionSpec.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- static get_weight_decay_exclusions()[source]#
Returns a tuple of parameter names for which weight decay should be excluded.
- Returns
A tuple containing ‘bias’, ‘normalization’, and ‘emb’ as exclusions.
- Return type
tuple
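A minimal sketch of turning these exclusions into a per-parameter weight-decay mask, assuming simple substring matching on slash-separated parameter names (the names below are hypothetical). A mapping like this, rebuilt as a pytree matching the parameters, is the kind of value the `mask` argument of `optax.adamw` accepts.

```python
EXCLUSIONS = ("bias", "normalization", "emb")  # the values documented above


def weight_decay_mask(param_names, exclusions=EXCLUSIONS):
    """Map each name to True (apply decay) or False (name contains an exclusion substring)."""
    return {name: not any(ex in name for ex in exclusions) for name in param_names}


names = [
    "transformer/token_embeddings/embedding",
    "transformer/layers/0/attn/out_proj/bias",
    "transformer/layers/0/ffn/proj_1/kernel",
]
print(weight_decay_mask(names))
```

Substring matching is deliberately coarse: any parameter whose path merely contains `emb` is also excluded, so naming conventions matter.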
- property granted_freq_max_position_embedding: int#
Returns the maximum position embedding size specifically for frequency-based position embeddings.
If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_context_length.
- Returns
The granted maximum position embedding size for frequency encoding.
- Return type
int
- property granted_mask_max_position_embedding: int#
Returns the maximum position embedding size specifically for mask-based position embeddings.
If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_context_length.
- Returns
The granted maximum position embedding size for mask encoding.
- Return type
int
- model_type: str = 'openelm'#
- class easydel.__init__.OpenELMForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
OpenELM model with a Causal Language Modeling head.
This model consists of the base OpenELM transformer (OpenELMModel) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction. Optionally, the input token embeddings can be tied to the output projection layer.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- transformer#
The core OpenELM transformer model.
- Type
- lm_head#
The linear layer for projecting hidden states to vocabulary logits. This is None if config.share_input_output_layers is True.
- Type
ParallelLinear, optional
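When `share_input_output_layers` is True, there is no separate `lm_head`, and the logits are obtained by reusing the token-embedding matrix. A minimal sketch with nested lists standing in for arrays (the function name and shapes are illustrative assumptions). It shows that the tied path equals an untied head whose weight is the transposed embedding:

```python
def lm_logits(hidden, embedding, lm_head=None):
    """hidden: [d]; embedding: [vocab][d]; lm_head: [d][vocab], or None when weights are tied."""
    if lm_head is None:
        # Tied: logits[v] = dot(hidden, embedding[v]), i.e. hidden @ embedding.T
        return [sum(h * e for h, e in zip(hidden, row)) for row in embedding]
    vocab_size = len(lm_head[0])
    return [sum(hidden[i] * lm_head[i][v] for i in range(len(hidden))) for v in range(vocab_size)]


hidden = [1.0, 2.0]
embedding = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # vocab=3, d=2
untied_head = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]  # embedding transposed: [d][vocab]
tied = lm_logits(hidden, embedding)
print(tied)                                         # [1.0, 2.0, 3.0]
print(max(range(len(tied)), key=tied.__getitem__))  # 2 (greedy next-token id)
```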
- class easydel.__init__.OpenELMModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
The base OpenELM model transformer.
This class represents the core transformer architecture of the OpenELM model, consisting of an embedding layer, multiple OpenELMDecoderLayer layers, and a final RMS normalization layer.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- token_embeddings#
Embedding layer for input tokens.
- Type
nn.Embed
- layers#
List of decoder layers.
- Type
tp.List[OpenELMDecoderLayer]
- gradient_checkpointing#
Gradient checkpointing configuration.
- property frequencies#
Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.
Uses self.config.get_basic_frequencies() and caches the result.
- Returns
The frequency components, potentially cached.
- Return type
jnp.ndarray
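The cached frequencies come from `config.get_basic_frequencies()`, whose internals are not documented here. As an assumption, the sketch below shows the standard RoPE inverse-frequency formula such values are typically built from, with `rope_freq_constant` playing the role of the base. It is not EasyDeL's implementation, and both helper names are hypothetical.

```python
def rope_inverse_frequencies(head_dim, base=10000.0):
    """inv_freq[j] = 1 / base ** (2j / head_dim) for j = 0 .. head_dim/2 - 1."""
    return [1.0 / base ** (i / head_dim) for i in range(0, head_dim, 2)]


def rope_angles(positions, inv_freq):
    """Rotation angle for every (position, frequency) pair; this table is what is worth caching."""
    return [[p * f for f in inv_freq] for p in positions]


inv_freq = rope_inverse_frequencies(head_dim=4)
print(rope_angles([0, 3], inv_freq))
```

Because the table depends only on the config (head dimension, base, maximum length), computing it once and reusing it across forward passes avoids redundant work.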
- class easydel.__init__.PartitionAxis(data_parallel_axis: str = 'dp', fully_sharded_data_parallel_axis: str = 'fsdp', tensor_parallel_axis: str = 'tp', sequence_parallel_axis: str = 'sp', expert_parallel_axis: str = 'ep', batch_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis, sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis, query_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis, head_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis, key_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis, hidden_state_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis, mlp_intermediate_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis, vocab_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis, expert_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis, expert_gate_axis: Optional[Union[Tuple[str, ...], str, Any]] = None, attention_dim_axis: Optional[Union[Tuple[str, ...], str, Any]] = None, bias_head_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = None, bias_key_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = None, generation_batch_axis: Optional[Union[Tuple[str, ...], str, Any]] = None, generation_query_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = None, generation_head_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis, generation_key_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis, generation_attention_dim_axis: Optional[Union[Tuple[str, ...], str, Any]] = None)[source]#
Bases:
object
Configuration for partitioning model axes across a device mesh.
Defines the mesh dimension names for standard parallelism strategies and maps logical model axes to these dimensions. Allows overriding defaults.
- Mesh Dimensions:
data_parallel_axis: Name for the data parallel mesh dimension. Default: “dp”.
fully_sharded_data_parallel_axis: Name for the FSDP mesh dimension. Default: “fsdp”.
tensor_parallel_axis: Name for the tensor parallel mesh dimension. Default: “tp”.
sequence_parallel_axis: Name for the sequence parallel mesh dimension. Default: “sp”.
expert_parallel_axis: Name for the expert parallel mesh dimension (MoE). Default: “ep”.
- Logical Model Axes:
Maps logical tensor axes (like batch, sequence, hidden) to one or more mesh dimension names defined above, or None if not partitioned. Defaults are derived from the standard mesh dimension names but can be overridden during instantiation. For example, head_axis defaults to the value of tensor_parallel_axis (‘tp’).
- attention_dim_axis: Optional[Union[Tuple[str, ...], str, Any]] = None#
- batch_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis#
- bias_head_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = None#
- bias_key_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = None#
- data_parallel_axis: str = 'dp'#
- expert_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis#
- expert_gate_axis: Optional[Union[Tuple[str, ...], str, Any]] = None#
- expert_parallel_axis: str = 'ep'#
- classmethod from_dict(data)#
Create an instance from a dictionary (deserialization).
- classmethod from_json(json_str)#
Create an instance from a JSON string.
- fully_sharded_data_parallel_axis: str = 'fsdp'#
- generation_attention_dim_axis: Optional[Union[Tuple[str, ...], str, Any]] = None#
- generation_batch_axis: Optional[Union[Tuple[str, ...], str, Any]] = None#
- generation_head_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis#
- generation_key_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis#
- generation_query_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = None#
- head_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis#
- key_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis#
- mlp_intermediate_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis#
- query_sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis#
- replace(**kwargs)#
- sequence_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis#
- sequence_parallel_axis: str = 'sp'#
- tensor_parallel_axis: str = 'tp'#
- to_dict()#
Convert the instance to a dictionary for JSON serialization.
- to_json(**kwargs)#
Convert the instance to a JSON string.
- vocab_axis: Optional[Union[Tuple[str, ...], str, Any]] = Ellipsis#
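Several logical axes default to `Ellipsis`, which appears to act as an "unset" sentinel that is later resolved from the mesh-dimension names. A minimal sketch of that pattern with a hypothetical two-field dataclass. It resolves only `head_axis` to `tensor_parallel_axis`, as described above; the real class's other default mappings are not reproduced here.

```python
from dataclasses import dataclass, replace
from typing import Any


@dataclass
class MiniPartitionAxis:
    tensor_parallel_axis: str = "tp"
    head_axis: Any = ...  # Ellipsis means "not given": derive it from the mesh names

    def __post_init__(self):
        if self.head_axis is Ellipsis:
            self.head_axis = self.tensor_parallel_axis


print(MiniPartitionAxis().head_axis)                              # tp
print(MiniPartitionAxis(tensor_parallel_axis="model").head_axis)  # model
print(MiniPartitionAxis(head_axis=None).head_axis)                # None (explicitly unpartitioned)
```

Using `Ellipsis` rather than `None` as the sentinel keeps `None` available to mean "do not partition this axis". One consequence of resolving in `__post_init__`: `dataclasses.replace(MiniPartitionAxis(), tensor_parallel_axis="x")` keeps `head_axis == "tp"`, because the already-resolved value is copied over.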
- class easydel.__init__.Phi3Config(vocab_size=32064, hidden_size=3072, intermediate_size=8192, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=None, resid_pdrop=0.0, embd_pdrop=0.0, attention_dropout=0.0, hidden_act='silu', max_position_embeddings=4096, original_max_position_embeddings=4096, initializer_range=0.02, rms_norm_eps=1e-05, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, bos_token_id=1, eos_token_id=32000, pad_token_id=32000, sliding_window=None, bits: Optional[int] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 32064) – Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.
hidden_size (int, optional, defaults to 3072) – Dimensionality of the encoder layers and the pooler layer.
intermediate_size (int, optional, defaults to 8192) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer encoder. Will default to num_attention_heads if not set.
resid_pdrop (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
embd_pdrop (float, optional, defaults to 0.0) – The dropout ratio for the embeddings.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.
max_position_embeddings (int, optional, defaults to 4096) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).
original_max_position_embeddings (int, optional, defaults to 4096) – The maximum sequence length the model was originally trained with. It is used to size the original RoPE embeddings when long-context rope scaling is applied.
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the rms normalization layers.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/value attention states (not used by all models). Only relevant if config.is_decoder=True.
tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.
rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.
rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.
bos_token_id (int, optional, defaults to 1) – The id of the beginning-of-sequence token.
eos_token_id (int, optional, defaults to 32000) – The id of the end-of-sequence token.
pad_token_id (int, optional, defaults to 32000) – The index of the padding token in the vocabulary.
sliding_window (int, optional) – The sliding window size.
bits (int, optional) – The number of bits to quantize the model to.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing strategy.
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- property granted_freq_max_position_embedding: int#
Returns the maximum position embedding size specifically for frequency-based position embeddings.
If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.
- Returns
The granted maximum position embedding size for frequency encoding.
- Return type
int
- property granted_mask_max_position_embedding: int#
Returns the maximum position embedding size specifically for mask-based position embeddings.
If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.
- Returns
The granted maximum position embedding size for mask encoding.
- Return type
int
- model_type: str = 'phi3'#
- class easydel.__init__.Phi3ForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Phi-3 model with a Causal Language Modeling head.
This model consists of the base Phi-3 transformer (Phi3Model) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction. Optionally, the input token embeddings can be tied to the output projection layer.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- lm_head#
The linear layer for projecting hidden states to vocabulary logits.
- Type
- class easydel.__init__.Phi3Model(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
The base Phi-3 model transformer.
This class represents the core transformer architecture of the Phi-3 model, consisting of an embedding layer, multiple Phi3DecoderLayer layers, and a final RMS normalization layer.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- embed_tokens#
Embedding layer for input tokens.
- Type
nn.Embed
- embed_dropout#
Dropout layer applied after embeddings.
- Type
nn.Dropout
- layers#
List of decoder layers.
- Type
tp.List[Phi3DecoderLayer]
- gradient_checkpointing#
Gradient checkpointing configuration.
- property frequencies#
Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.
Uses self.config.get_basic_frequencies() and caches the result.
- Returns
The frequency components, potentially cached.
- Return type
jnp.ndarray
- class easydel.__init__.PhiConfig(vocab_size=51200, hidden_size=2048, intermediate_size=8192, num_hidden_layers=24, num_attention_heads=32, num_key_value_heads=None, resid_pdrop=0.0, embd_pdrop=0.0, attention_dropout=0.0, hidden_act='gelu_new', max_position_embeddings=2048, initializer_range=0.02, layer_norm_eps=1e-05, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, partial_rotary_factor=0.5, qk_layernorm=False, bos_token_id=1, eos_token_id=2, bits: Optional[int] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 51200) – Vocabulary size of the Phi model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.
hidden_size (int, optional, defaults to 2048) – Dimensionality of the encoder layers and the pooler layer.
intermediate_size (int, optional, defaults to 8192) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (int, optional, defaults to 24) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer encoder. Will default to num_attention_heads if not set.
resid_pdrop (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
embd_pdrop (float, optional, defaults to 0.0) – The dropout ratio for the embeddings.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
hidden_act (str or function, optional, defaults to “gelu_new”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.
max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.
rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.
rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.
partial_rotary_factor (float, optional, defaults to 0.5) – The factor for partial rotary embeddings.
qk_layernorm (bool, optional, defaults to False) – Whether to apply layer normalization to the query and key tensors.
bos_token_id (int, optional, defaults to 1) – The id of the beginning-of-sequence token.
eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.
bits (int, optional) – The number of bits to quantize the model to.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.
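With the defaults above, the per-head size and the rotated slice of each head follow from simple arithmetic. A minimal sketch, assuming (as in the Hugging Face Phi implementation) that head_dim is hidden_size // num_attention_heads and that only the first int(head_dim * partial_rotary_factor) channels of each head receive RoPE; the helper `rotary_split` is hypothetical:

```python
import numpy as np

def rotary_split(hidden_size: int = 2048, num_attention_heads: int = 32,
                 partial_rotary_factor: float = 0.5):
    """Return (head_dim, rotary_dim) for a Phi-style config (hypothetical helper)."""
    head_dim = hidden_size // num_attention_heads
    rotary_dim = int(head_dim * partial_rotary_factor)
    return head_dim, rotary_dim

head_dim, rotary_dim = rotary_split()
print(head_dim, rotary_dim)  # 64 32

# RoPE is applied only to the leading rotary_dim channels of each head;
# the remaining channels pass through unchanged.
q = np.zeros((1, 32, head_dim))  # (seq_len, num_heads, head_dim)
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
print(q_rot.shape, q_pass.shape)  # (1, 32, 32) (1, 32, 32)
```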
- attribute_map: dict[str, str] = {'hidden_size': 'n_embd', 'max_position_embeddings': 'n_positions', 'num_attention_heads': 'num_attention_heads', 'num_hidden_layers': 'num_hidden_layers'}#
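The attribute_map lets the config answer to alternative attribute names. In Hugging Face `PretrainedConfig`, each key is an alias and each value is the attribute it resolves to, for both reads and writes. The toy class below is a hypothetical re-creation of that redirection (not EasyDeL's code), using two of the entries shown:

```python
class AliasedConfig:
    """Toy config that redirects aliased attribute names (hypothetical)."""

    attribute_map = {"hidden_size": "n_embd", "max_position_embeddings": "n_positions"}

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            setattr(self, name, value)

    def __setattr__(self, name, value):
        # Writes to an alias are stored under the canonical attribute name.
        name = type(self).attribute_map.get(name, name)
        super().__setattr__(name, value)

    def __getattr__(self, name):
        # Only called when normal lookup fails, i.e. for alias names.
        target = type(self).attribute_map.get(name)
        if target is None:
            raise AttributeError(name)
        return getattr(self, target)

cfg = AliasedConfig(hidden_size=2048, max_position_embeddings=2048)
print(cfg.n_embd, cfg.hidden_size)  # 2048 2048
```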
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
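The return value is a tuple of (pattern, PartitionSpec) pairs. A minimal sketch of how such rules are commonly consumed, assuming first-match-wins regex matching against "/"-joined parameter paths; the patterns and the mesh axis names "fsdp" and "tp" are hypothetical, not Phi's actual rules:

```python
import re
from jax.sharding import PartitionSpec as P

# Hypothetical rules; real ones come from config.get_partition_rules().
RULES = (
    (r"embed_tokens/embedding", P("tp", "fsdp")),
    (r"self_attn/(q_proj|k_proj|v_proj)/kernel", P("fsdp", "tp")),
    (r"lm_head/kernel", P("fsdp", "tp")),
    (r".*", P()),  # fallback: replicate everything else
)

def spec_for(path: str, rules=RULES) -> P:
    """Return the PartitionSpec of the first rule whose regex matches `path`."""
    for pattern, spec in rules:
        if re.search(pattern, path):
            return spec
    raise ValueError(f"no partition rule matches {path!r}")

print(spec_for("model/layers/0/self_attn/q_proj/kernel"))
```

Ordering matters under this scheme: the catch-all `.*` must come last, or it would shadow every specific rule.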
- property granted_freq_max_position_embedding: int#
Returns the maximum position embedding size specifically for frequency-based position embeddings.
If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.
- Returns
The granted maximum position embedding size for frequency encoding.
- Return type
int
- property granted_mask_max_position_embedding: int#
Returns the maximum position embedding size specifically for mask-based position embeddings.
If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.
- Returns
The granted maximum position embedding size for mask encoding.
- Return type
int
- model_type: str = 'phi'#
- class easydel.__init__.PhiForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Phi model with a Causal Language Modeling head.
This model consists of the base Phi transformer (PhiModel) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction. Optionally, the input token embeddings can be tied to the output projection layer.
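The projection described above is a single matrix product per position. A minimal JAX sketch with toy shapes; when embeddings are tied, the transposed (vocab, hidden) embedding matrix stands in for the lm_head kernel. The function name and shapes are illustrative, not the EasyDeL implementation:

```python
import jax
import jax.numpy as jnp

def project_to_vocab(hidden, embedding, lm_head_kernel=None):
    """Map (batch, seq, hidden) states to (batch, seq, vocab) logits.

    If `lm_head_kernel` is None, reuse the input embedding (tied weights).
    """
    kernel = embedding.T if lm_head_kernel is None else lm_head_kernel
    return hidden @ kernel

key_h, key_e, key_w = jax.random.split(jax.random.PRNGKey(0), 3)
vocab, dim = 11, 8
hidden = jax.random.normal(key_h, (2, 5, dim))
embedding = jax.random.normal(key_e, (vocab, dim))  # (vocab, hidden)

tied_logits = project_to_vocab(hidden, embedding)
untied_logits = project_to_vocab(hidden, embedding, jax.random.normal(key_w, (dim, vocab)))
print(tied_logits.shape, untied_logits.shape)  # (2, 5, 11) (2, 5, 11)
```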
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- lm_head#
The linear layer for projecting hidden states to vocabulary logits.
- Type
- class easydel.__init__.PhiModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
The base Phi model transformer.
This class represents the core transformer architecture of the Phi model, consisting of an embedding layer, multiple PhiDecoderLayer layers, and a final layer normalization.
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- embed_tokens#
Embedding layer for input tokens.
- Type
nn.Embed
- layers#
List of decoder layers.
- Type
tp.List[PhiDecoderLayer]
- final_layernorm#
Final layer normalization.
- Type
nn.LayerNorm
- embed_dropout#
Dropout layer applied after embeddings.
- Type
nn.Dropout
- gradient_checkpointing#
Gradient checkpointing configuration.
- property frequencies#
Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.
Uses self.config.get_basic_frequencies() and caches the result.
- Returns
The frequency components, potentially cached.
- Return type
jnp.ndarray
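For RoPE, the frequency components are usually per-channel inverse frequencies multiplied by each position. A minimal sketch of the standard, unscaled RoPE table, assuming base `theta` (rope_theta) and rotary dimension `dim`. This is the textbook formula, not necessarily what get_basic_frequencies() returns or how it lays out its output:

```python
import jax.numpy as jnp

def rope_frequencies(max_positions: int, dim: int, theta: float = 10000.0):
    """Return (cos, sin) tables of shape (max_positions, dim // 2)."""
    inv_freq = 1.0 / (theta ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim))
    positions = jnp.arange(max_positions, dtype=jnp.float32)
    angles = jnp.outer(positions, inv_freq)  # angles[p, i] = p * inv_freq[i]
    return jnp.cos(angles), jnp.sin(angles)

# dim=32 corresponds to the rotated slice of a 64-wide head with partial_rotary_factor=0.5.
cos, sin = rope_frequencies(max_positions=2048, dim=32)
print(cos.shape)  # (2048, 16)
```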
- class easydel.__init__.PhiMoeConfig(vocab_size=32064, hidden_size=4096, intermediate_size=6400, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=8, hidden_act='silu', max_position_embeddings=131072, initializer_range=0.02, rms_norm_eps=1e-05, use_cache=True, pad_token_id=None, bos_token_id=1, eos_token_id=2, tie_word_embeddings=False, rope_theta=1000000.0, rope_scaling=None, sliding_window=None, attention_dropout=0.0, num_experts_per_tok=2, num_local_experts=16, output_router_logits=False, router_aux_loss_coef=0.001, router_jitter_noise=0.01, input_jitter_noise=0.0, attention_bias=False, embd_pdrop: float = 0.0, lm_head_bias=False, bits: Optional[int] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 32064) – Vocabulary size of the PhiMoE model. Defines the number of different tokens that can be represented by the input_ids passed when calling [PhiMoeModel].
hidden_size (int, optional, defaults to 4096) – Dimension of the hidden representations.
intermediate_size (int, optional, defaults to 6400) – Dimension of the MLP representations.
num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (int, optional, defaults to 8) – The number of key/value heads used to implement Grouped Query Attention (GQA). If num_key_value_heads=num_attention_heads, the model uses Multi-Head Attention (MHA); if num_key_value_heads=1, it uses Multi-Query Attention (MQA); otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by mean-pooling all the original heads within that group. For more details, check out [this paper](https://arxiv.org/pdf/2305.13245.pdf). If not specified, defaults to 8.
hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) in the decoder.
max_position_embeddings (int, optional, defaults to 4096*32, i.e. 131072) – The maximum sequence length that this model might ever be used with. Mixtral’s sliding window attention allows sequences of up to 4096*32 tokens.
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (float, optional, defaults to 1e-05) – The epsilon used by the rms normalization layers.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
pad_token_id (int, optional) – The id of the padding token.
bos_token_id (int, optional, defaults to 1) – The id of the “beginning-of-sequence” token.
eos_token_id (int, optional, defaults to 2) – The id of the “end-of-sequence” token.
tie_word_embeddings (bool, optional, defaults to False) – Whether the model’s input and output word embeddings should be tied.
rope_theta (float, optional, defaults to 1000000.0) – The base period of the RoPE embeddings.
rope_scaling (dict, optional) – The scaling strategy for the RoPE embeddings. If None, no scaling is applied. If a dictionary, it must contain the following keys: type, short_factor, long_factor, short_mscale, long_mscale and original_max_position_embeddings. The type must be longrope; short_mscale and long_mscale must be numbers; short_factor and long_factor must be lists of numbers whose length equals half of the attention head size; and original_max_position_embeddings must be an integer.
sliding_window (int, optional) – Sliding window attention window size. If not specified, will default to 262144.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
num_experts_per_tok (int, optional, defaults to 2) – The number of experts each token is routed to; can also be interpreted as the top-k routing parameter.
num_local_experts (int, optional, defaults to 16) – Number of experts per Sparse MLP layer.
output_router_logits (bool, optional, defaults to False) – Whether or not the router logits should be returned by the model. Enabling this also allows the model to output the auxiliary loss.
router_aux_loss_coef (float, optional, defaults to 0.001) – The aux loss factor for the total loss.
router_jitter_noise (float, optional, defaults to 0.01) – Amount of noise to add to the router.
bits (int, optional) – The number of bits to quantize the model to.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.
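The num_key_value_heads entry describes building a GQA checkpoint from a multi-head one by mean-pooling the heads in each group. A minimal NumPy sketch of that conversion for a key (or value) projection kernel, assuming heads are laid out contiguously along the output axis as (hidden, num_heads * head_dim) and that consecutive heads form a group; `mean_pool_kv_heads` is a hypothetical helper:

```python
import numpy as np

def mean_pool_kv_heads(kernel: np.ndarray, num_heads: int, num_kv_heads: int) -> np.ndarray:
    """Convert an MHA k/v kernel (hidden, num_heads*head_dim) to GQA (hidden, num_kv_heads*head_dim)."""
    if num_heads % num_kv_heads:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    hidden, width = kernel.shape
    head_dim = width // num_heads
    group = num_heads // num_kv_heads
    # (hidden, kv_heads, heads_per_group, head_dim): consecutive heads share a group.
    grouped = kernel.reshape(hidden, num_kv_heads, group, head_dim)
    return grouped.mean(axis=2).reshape(hidden, num_kv_heads * head_dim)

# With the defaults (32 attention heads, 8 KV heads) each KV head averages 4 heads.
k_proj = np.random.default_rng(0).normal(size=(4, 32 * 8))  # hidden=4, head_dim=8
gqa = mean_pool_kv_heads(k_proj, num_heads=32, num_kv_heads=8)
print(gqa.shape)  # (4, 64)
```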
- attach_custom_arguments(bits: Optional[int] = None, embd_pdrop: float = 0.0, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#
Attaches custom arguments to the configuration object.
This method allows dynamically adding or overriding configuration attributes. It primarily sets attributes related to quantization, dropout, and gradient checkpointing. Any additional keyword arguments are also set as attributes if they don’t already exist.
- Parameters
bits (tp.Optional[int], optional) – Quantization bits. Defaults to None.
embd_pdrop (float, optional) – Dropout probability for embeddings. Defaults to 0.0.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing strategy. Defaults to EasyDeLGradientCheckPointers.NONE.
**kwargs – Additional keyword arguments to attach.
- get_partition_rules(fully_sharded_data_parallel: bool = True)[source]#
Get the partition rules for the model.
- Parameters
fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- property granted_freq_max_position_embedding: int#
Returns the maximum position embedding size specifically for frequency-based position embeddings.
If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.
- Returns
The granted maximum position embedding size for frequency encoding.
- Return type
int
- property granted_mask_max_position_embedding: int#
Returns the maximum position embedding size specifically for mask-based position embeddings.
If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.
- Returns
The granted maximum position embedding size for mask encoding.
- Return type
int
- model_type: str = 'phimoe'#
- class easydel.__init__.PhiMoeForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
PhiMoE model with a Causal Language Modeling head.
This model consists of the base PhiMoE transformer (PhiMoeModel) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction. Optionally, the input token embeddings can be tied to the output projection layer.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- model#
The core PhiMoE transformer model.
- Type
- lm_head#
The linear layer for projecting hidden states to vocabulary logits.
- Type
- class easydel.__init__.PhiMoeModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
The base PhiMoE model transformer.
This class represents the core transformer architecture of the PhiMoE model, consisting of an embedding layer, multiple PhiMoeDecoderLayer layers (which include Sparse Mixture of Experts blocks), and a final RMS normalization layer.
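A Sparse Mixture of Experts block sends each token to num_experts_per_tok of the num_local_experts experts and mixes their outputs. A minimal JAX sketch of generic top-k softmax routing with toy linear experts; it omits router jitter noise and any PhiMoE-specific gating, so it illustrates the Mixtral-style scheme rather than the exact PhiMoE router. All names are illustrative:

```python
import jax
import jax.numpy as jnp

def moe_forward(x, router_w, expert_ws, top_k=2):
    """Generic top-k mixture-of-experts layer (dense reference, not optimized).

    x: (tokens, hidden); router_w: (hidden, num_experts);
    expert_ws: (num_experts, hidden, hidden), one toy linear "expert" each.
    """
    num_experts = router_w.shape[-1]
    logits = x @ router_w                                  # (tokens, num_experts)
    top_vals, top_idx = jax.lax.top_k(logits, top_k)       # (tokens, top_k)
    weights = jax.nn.softmax(top_vals, axis=-1)            # renormalize over chosen experts
    # Scatter the top-k weights into a dense (tokens, num_experts) gate matrix.
    gates = jnp.sum(jax.nn.one_hot(top_idx, num_experts) * weights[..., None], axis=1)
    expert_out = jnp.einsum("td,edh->teh", x, expert_ws)   # every expert on every token
    return jnp.einsum("te,teh->th", gates, expert_out)     # (tokens, hidden)

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (4, 8))
router_w = jax.random.normal(k2, (8, 16))         # 16 experts, as in num_local_experts
expert_ws = jax.random.normal(k3, (16, 8, 8))
y = moe_forward(x, router_w, expert_ws, top_k=2)  # top_k as in num_experts_per_tok
print(y.shape)  # (4, 8)
```

Running every expert on every token keeps the sketch short and exact; production MoE layers dispatch tokens so that only the selected experts do work.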
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- embed_tokens#
Embedding layer for input tokens.
- Type
nn.Embed
- layers#
List of decoder layers.
- Type
tp.List[PhiMoeDecoderLayer]
- embed_dropout#
Dropout layer applied after embeddings.
- Type
nn.Dropout
- gradient_checkpointing#
Gradient checkpointing configuration.
- class easydel.__init__.PixtralVisionConfig(hidden_size: int = 1024, intermediate_size: int = 4096, num_hidden_layers: int = 24, num_attention_heads: int = 16, num_channels: int = 3, image_size: int = 1024, patch_size: int = 16, hidden_act: str = 'gelu', attention_dropout: float = 0.0, rope_theta: float = 10000.0, initializer_range: int = 0.02, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
This is the configuration class to store the configuration of a [PixtralVisionModel]. It is used to instantiate a Pixtral vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration similar to that of the vision encoder used by Pixtral-12B.
e.g. [pixtral-hf/pixtral-9b](https://huggingface.co/pixtral-hf/pixtral-9b)
Configuration objects inherit from [PretrainedConfig] and can be used to control the model outputs. Read the documentation from [PretrainedConfig] for more information.
- Parameters
hidden_size (int, optional, defaults to 1024) – Dimension of the hidden representations.
intermediate_size (int, optional, defaults to 4096) – Dimension of the MLP representations.
num_hidden_layers (int, optional, defaults to 24) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 16) – Number of attention heads in the Transformer encoder.
num_channels (int, optional, defaults to 3) – Number of input channels in the input images.
image_size (int, optional, defaults to 1024) – Max dimension of the input images.
patch_size (int, optional, defaults to 16) – Size of the image patches.
hidden_act (str, optional, defaults to “gelu”) – Activation function used in the hidden layers.
attention_dropout (float, optional, defaults to 0.0) – Dropout probability for the attention layers.
rope_theta (float, optional, defaults to 10000.0) – The base period of the RoPE embeddings.
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
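With the defaults above, an image is cut into a grid of patches, each of which becomes one token of the vision encoder's sequence. The arithmetic, assuming non-overlapping square patches (stride equal to patch_size) and image sides that are multiples of patch_size; `num_patches` is a hypothetical helper:

```python
def num_patches(height: int, width: int, patch_size: int = 16) -> int:
    """Number of non-overlapping patch tokens for an image of the given size."""
    return (height // patch_size) * (width // patch_size)

print(num_patches(1024, 1024))  # 4096  (a 64 x 64 grid)
print(num_patches(512, 768))    # 1536  (32 x 48)
```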
Example:
```python
>>> from transformers import PixtralVisionModel, PixtralVisionConfig

>>> # Initializing a Pixtral-12B style configuration
>>> config = PixtralVisionConfig()

>>> # Initializing a model (with randomly initialized weights) from the configuration
>>> model = PixtralVisionModel(config)

>>> # Accessing the model configuration
>>> configuration = model.config
```
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- model_type: str = 'pixtral'#
- class easydel.__init__.PixtralVisionModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
The Pixtral Vision Model transformer.
This class implements the complete Pixtral vision model, including patch embedding via convolution and the main transformer stack.
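A patch-embedding convolution whose kernel size and stride both equal patch_size is equivalent to cutting the image into patches and applying one shared linear map to each flattened patch. The JAX sketch below checks this numerically on toy shapes (NHWC layout, no bias); it does not assert Pixtral's actual layer settings:

```python
import jax
import jax.numpy as jnp

def patch_conv(images, kernel, patch_size):
    """images: (N, H, W, C); kernel: (p, p, C, D) -> (N, H/p, W/p, D)."""
    return jax.lax.conv_general_dilated(
        images, kernel,
        window_strides=(patch_size, patch_size),
        padding="VALID",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )

def patchify_linear(images, kernel, patch_size):
    """Same result via reshape + matmul."""
    n, h, w, c = images.shape
    p = patch_size
    patches = images.reshape(n, h // p, p, w // p, p, c).transpose(0, 1, 3, 2, 4, 5)
    patches = patches.reshape(n, h // p, w // p, p * p * c)  # flatten (row, col, channel)
    return patches @ kernel.reshape(p * p * c, -1)

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
images = jax.random.normal(k1, (1, 32, 48, 3))
kernel = jax.random.normal(k2, (16, 16, 3, 8))
a = patch_conv(images, kernel, 16)
b = patchify_linear(images, kernel, 16)
print(a.shape)  # (1, 2, 3, 8)
print(bool(jnp.allclose(a, b, atol=1e-3, rtol=1e-3)))
```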
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- patch_conv#
Convolutional layer for patch embedding.
- Type
nn.Conv
- transformer#
The main transformer stack.
- Type
- property frequencies#
Cached property to compute and retrieve RoPE frequencies.
- class easydel.__init__.PyTree[source]#
Bases:
_PyTreeNodeBase
Base class for dataclasses that should act like a JAX pytree node.
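What it means for a dataclass to act like a pytree node can be shown with JAX's public registration API. A minimal sketch using `jax.tree_util.register_pytree_node_class`; this is not how `_PyTreeNodeBase` is implemented, only an illustration of the behavior (array fields become leaves, other fields travel as static metadata). `TrainState` is an illustrative class name:

```python
import dataclasses
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class TrainState:  # illustrative example class
    step: int      # static metadata, not a leaf
    params: dict   # subtree whose arrays are leaves

    def tree_flatten(self):
        return (self.params,), (self.step,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (step,) = aux_data
        (params,) = children
        return cls(step=step, params=params)

state = TrainState(step=3, params={"w": jnp.ones(2), "b": jnp.zeros(2)})
doubled = jax.tree_util.tree_map(lambda x: x * 2, state)
print(type(doubled).__name__, doubled.step, doubled.params["w"])  # TrainState 3 [2. 2.]
```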
- class easydel.__init__.Qwen2Config(vocab_size=151936, hidden_size=4096, intermediate_size=22016, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=32, hidden_act='silu', max_position_embeddings=32768, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, use_sliding_window=False, sliding_window=4096, max_window_layers=28, attention_dropout=0.0, resid_pdrop: float = 0.0, embd_pdrop: float = 0.0, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, fcm_min_ratio: float = 0.0, fcm_max_ratio: float = 0.0, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, number_rep_kv: int = 1, bits: Optional[int] = None, scan_layers: bool = True, rope_scaling: Optional[Mapping[str, str | float]] = None, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [PretrainedConfig] and can be used to control the model outputs. Read the documentation from [PretrainedConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 151936) – Vocabulary size of the Qwen-2 model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.
hidden_size (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.
intermediate_size (int, optional, defaults to 22016) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (int, optional, defaults to 32) – Number of key and value heads for each attention layer in the Transformer encoder.
hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.
max_position_embeddings (int, optional, defaults to 32768) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.
rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.
use_sliding_window (bool, optional, defaults to False) – Whether to use a sliding window attention.
sliding_window (int, optional, defaults to 4096) – The sliding window size.
max_window_layers (int, optional, defaults to 28) – The maximum number of layers to use for the sliding window attention.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
resid_pdrop (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
embd_pdrop (float, optional, defaults to 0.0) – The dropout ratio for the embeddings.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.
fcm_min_ratio (float, optional, defaults to 0.0) – The minimum ratio for Forgetful Causal Masking (FCM).
fcm_max_ratio (float, optional, defaults to 0.0) – The maximum ratio for Forgetful Causal Masking (FCM).
use_scan_mlp (bool, optional, defaults to False) – Whether to use the scan implementation for the MLP.
scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size to use when scanning the MLP.
number_rep_kv (int, optional, defaults to 1) – Number of repetitions for the key and value vectors.
bits (int, optional) – The number of bits to quantize the model to.
scan_layers (bool, optional, defaults to True) – Whether to use the scan implementation for the layers.
rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.
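use_sliding_window and sliding_window restrict each query to a band of recent keys. A minimal NumPy sketch of a causal sliding-window mask, assuming the common convention that the query at position i may attend to keys j with i - window < j <= i (implementations can differ by one at the boundary). Which layers apply the window, as governed by max_window_layers, is not modeled here:

```python
import numpy as np

def sliding_window_causal_mask(seq_len: int, window: int) -> np.ndarray:
    """Boolean (seq_len, seq_len) mask; True where query i may attend to key j."""
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    return (j <= i) & (j > i - window)

print(sliding_window_causal_mask(5, 3).astype(int))
```

Once window >= seq_len, the mask reduces to an ordinary causal (lower-triangular) mask.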
- attach_custom_arguments(resid_pdrop: float = 0.0, embd_pdrop: float = 0.0, attention_dropout: float = 0.0, tie_word_embeddings: bool = False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, fcm_min_ratio: float = 0.0, fcm_max_ratio: float = 0.0, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, number_rep_kv: int = 1, bits: Optional[int] = None, rope_theta: float = 10000.0, hidden_act: str = 'silu', scan_layers: bool = True, rope_scaling: Optional[Mapping[str, str | float]] = None, **kwargs)[source]#
Attaches custom arguments to the configuration object, setting each of the following as an attribute:
- Parameters
self – The configuration instance being updated.
resid_pdrop – float: Dropout rate for residual connections.
embd_pdrop – float: Dropout probability for the embeddings.
attention_dropout – float: Dropout probability for the attention probabilities.
tie_word_embeddings – bool: Whether to tie the input and output word embeddings.
gradient_checkpointing – EasyDeLGradientCheckPointers: Gradient checkpointing strategy, which trades recomputation for lower activation memory.
fcm_min_ratio – float: Minimum ratio for Forgetful Causal Masking (FCM).
fcm_max_ratio – float: Maximum ratio for Forgetful Causal Masking (FCM).
use_scan_mlp – bool: Determine whether to use the scan_mlp function or not.
scan_mlp_chunk_size – int: Set the chunk size for scan_mlp.
number_rep_kv – int: Determine how many times the key and value vectors are repeated.
bits – tp.Optional[int]: Determine the number of bits used in the quantization.
rope_theta – float: Base value for RoPE.
hidden_act – str: Activation function name.
scan_layers – bool: Determine whether to use scan layers or not.
rope_scaling (tp.Optional[tp.Mapping[str, str | float]], optional) – RoPE scaling configuration.
**kwargs – Additional keyword arguments to attach.
- get_partition_rules(fully_sharded_data_parallel: bool = True)[source]#
Get the partition rules for the model.
- Parameters
fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- static get_weight_decay_exclusions()[source]#
Returns a tuple of parameter names for which weight decay should be excluded.
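Exclusions like these are typically turned into a boolean mask over the parameters so the optimizer skips decay where the mask is False; a pytree mask of this kind is what, for example, the `mask` argument of `optax.adamw` accepts. A minimal sketch over a flat dict of parameter paths, assuming plain substring matching; the exclusion names are hypothetical placeholders, not the tuple this method returns:

```python
def weight_decay_mask(param_paths, exclusions):
    """Map each parameter path to True (apply decay) or False (excluded).

    Assumes exclusions are matched as plain substrings of the path.
    """
    return {path: not any(name in path for name in exclusions) for path in param_paths}

paths = [
    "model/layers/0/self_attn/q_proj/kernel",
    "model/layers/0/self_attn/q_proj/bias",
    "model/norm/kernel",
]
mask = weight_decay_mask(paths, exclusions=("bias", "norm"))  # hypothetical exclusions
for path, decay in mask.items():
    print(f"{path}: {'decay' if decay else 'no decay'}")
```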
- property granted_freq_max_position_embedding: int#
Returns the maximum position embedding size specifically for frequency-based position embeddings.
If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.
- Returns
The granted maximum position embedding size for frequency encoding.
- Return type
int
- property granted_mask_max_position_embedding: int#
Returns the maximum position embedding size specifically for mask-based position embeddings.
If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.
- Returns
The granted maximum position embedding size for mask encoding.
- Return type
int
- model_type: str = 'qwen2'#
- class easydel.__init__.Qwen2ForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Qwen2 model with a Causal Language Modeling head.
This model consists of the base Qwen2 transformer (Qwen2Model) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction. Optionally, the input token embeddings can be tied to the output projection layer.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- model#
The core Qwen2 transformer model.
- Type
- lm_head#
The linear layer for projecting hidden states to vocabulary logits.
- Type
- class easydel.__init__.Qwen2ForSequenceClassification(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Qwen2 model with a Sequence Classification head.
This model consists of the base Qwen2 transformer (Qwen2Model) followed by a linear layer (score) that projects the transformer’s output hidden states (typically the hidden state of the last token or a pooled representation) to the number of classes for classification.
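In a padded batch, "the last token" usually means the last non-padding position of each row. A minimal JAX sketch that selects per-position classification logits at that position, assuming right padding and an attention_mask of 1s for real tokens; the helper name is hypothetical and the exact pooling rule this class uses is not asserted:

```python
import jax.numpy as jnp

def pool_last_token(logits, attention_mask):
    """logits: (batch, seq, num_labels); attention_mask: (batch, seq) of 0/1.

    Returns (batch, num_labels): logits at the last real (mask == 1) token,
    assuming right padding.
    """
    last_idx = jnp.sum(attention_mask, axis=-1) - 1  # (batch,)
    return logits[jnp.arange(logits.shape[0]), last_idx]

logits = jnp.arange(2 * 4 * 3, dtype=jnp.float32).reshape(2, 4, 3)
mask = jnp.array([[1, 1, 1, 0], [1, 1, 1, 1]])
print(pool_last_token(logits, mask))
```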
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- model#
The core Qwen2 transformer model.
- Type
- score#
The linear layer for classification.
- Type
- class easydel.__init__.Qwen2Model(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
The base Qwen2 model transformer.
This class represents the core transformer architecture of the Qwen2 model, consisting of an embedding layer, multiple Qwen2DecoderLayer layers, and a final RMS normalization layer.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- embed_tokens#
Embedding layer for input tokens.
- Type
nn.Embed
- layers#
List of decoder layers.
- Type
tp.List[Qwen2DecoderLayer]
- dropout#
Dropout layer applied after embeddings.
- Type
nn.Dropout
- gradient_checkpointing#
Gradient checkpointing configuration.
- class easydel.__init__.Qwen2MoeConfig(vocab_size=151936, hidden_size=2048, intermediate_size=5632, num_hidden_layers=24, num_attention_heads=16, num_key_value_heads=16, hidden_act='silu', max_position_embeddings=32768, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, tie_word_embeddings=False, qkv_bias=False, rope_theta=10000.0, use_sliding_window=False, sliding_window=4096, max_window_layers=28, attention_dropout=0.0, decoder_sparse_step=1, moe_intermediate_size=1408, shared_expert_intermediate_size=5632, num_experts_per_tok=4, num_experts=60, norm_topk_prob=False, output_router_logits=False, router_aux_loss_coef=0.001, mlp_only_layers=None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 151936) – Vocabulary size of the Qwen-2 MoE model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.
hidden_size (int, optional, defaults to 2048) – Dimensionality of the encoder layers and the pooler layer.
intermediate_size (int, optional, defaults to 5632) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (int, optional, defaults to 24) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 16) – Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (int, optional, defaults to 16) – Number of key and value heads for each attention layer in the Transformer encoder.
hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.
max_position_embeddings (int, optional, defaults to 32768) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.
rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.
use_sliding_window (bool, optional, defaults to False) – Whether to use a sliding window attention.
sliding_window (int, optional, defaults to 4096) – The sliding window size.
max_window_layers (int, optional, defaults to 28) – The maximum number of layers to use for the sliding window attention.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
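decoder_sparse_step (int, optional, defaults to 1) – The frequency of sparse MoE layers; with the default of 1, every decoder layer not listed in mlp_only_layers uses a sparse MoE block.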
moe_intermediate_size (int, optional, defaults to 1408) – The intermediate size of the MoE layer.
shared_expert_intermediate_size (int, optional, defaults to 5632) – The intermediate size of the shared expert.
num_experts_per_tok (int, optional, defaults to 4) – The number of experts per token.
num_experts (int, optional, defaults to 60) – The number of experts.
norm_topk_prob (bool, optional, defaults to False) – Whether to normalize the top-k probabilities.
output_router_logits (bool, optional, defaults to False) – Whether to output the router logits.
router_aux_loss_coef (float, optional, defaults to 0.001) – The coefficient for the router auxiliary loss.
mlp_only_layers (list of int, optional) – Indices (0-based) of the layers that use a dense MLP instead of a sparse MoE block.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.
bits (int, optional) – The number of bits to quantize the model to.
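decoder_sparse_step, num_experts and mlp_only_layers together decide which decoder layers get a sparse MoE block. A minimal sketch, assuming the rule used by the Hugging Face Qwen2-MoE implementation: layer i is sparse when it is not in mlp_only_layers, num_experts > 0, and (i + 1) % decoder_sparse_step == 0. The helper is hypothetical:

```python
def moe_layer_indices(num_hidden_layers=24, decoder_sparse_step=1,
                      num_experts=60, mlp_only_layers=None):
    """Return the 0-based indices of layers that use a sparse MoE block."""
    dense = set(mlp_only_layers or [])
    return [
        i for i in range(num_hidden_layers)
        if i not in dense and num_experts > 0 and (i + 1) % decoder_sparse_step == 0
    ]

print(len(moe_layer_indices()))  # 24
print(moe_layer_indices(num_hidden_layers=8, decoder_sparse_step=2))  # [1, 3, 5, 7]
print(moe_layer_indices(num_hidden_layers=4, mlp_only_layers=[0]))  # [1, 2, 3]
```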
- get_partition_rules(fully_sharded_data_parallel: bool = True)[source]#
Get the partition rules for the model.
- Parameters
fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- static get_weight_decay_exclusions()[source]#
Returns a tuple of parameter names for which weight decay should be excluded.
- property granted_freq_max_position_embedding: int#
Returns the maximum position embedding size specifically for frequency-based position embeddings.
If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.
- Returns
The granted maximum position embedding size for frequency encoding.
- Return type
int
- property granted_mask_max_position_embedding: int#
Returns the maximum position embedding size specifically for mask-based position embeddings.
If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.
- Returns
The granted maximum position embedding size for mask encoding.
- Return type
int
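The fallback behavior of these two properties can be sketched with a small dataclass. PositionLimits is a hypothetical stand-in for the config, and it treats a field as set when it is not None, which is an assumption about how the real properties test for it.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class PositionLimits:
    # Hypothetical stand-in for the relevant config fields.
    max_position_embeddings: int
    freq_max_position_embeddings: Optional[int] = None
    mask_max_position_embeddings: Optional[int] = None

    @property
    def granted_freq_max_position_embedding(self) -> int:
        # Prefer the frequency-specific limit; otherwise fall back to the global one.
        if self.freq_max_position_embeddings is not None:
            return self.freq_max_position_embeddings
        return self.max_position_embeddings

    @property
    def granted_mask_max_position_embedding(self) -> int:
        # Prefer the mask-specific limit; otherwise fall back to the global one.
        if self.mask_max_position_embeddings is not None:
            return self.mask_max_position_embeddings
        return self.max_position_embeddings


base = PositionLimits(max_position_embeddings=4096)
extended = PositionLimits(4096, freq_max_position_embeddings=8192)
```

This lets the rotary-frequency tables be built for a longer range than the attention mask (or the reverse) without changing max_position_embeddings itself.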
- model_type: str = 'qwen2_moe'#
- class easydel.__init__.Qwen2MoeForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Qwen2 MoE model with a Causal Language Modeling (CLM) head.
This class wraps the base Qwen2MoeModel and adds a linear layer (language model head) to predict the next token logits.
- config#
Configuration object for the model.
- Type
- model#
The base Qwen2 MoE model.
- Type
- lm_head#
The language model head (linear layer).
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for matrix multiplications.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- class easydel.__init__.Qwen2MoeForSequenceClassification(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Qwen2 MoE model with a sequence classification head.
This class wraps the base Qwen2MoeModel and adds a linear layer on top to perform sequence classification tasks.
- config#
Configuration object for the model.
- Type
- model#
The base Qwen2 MoE model.
- Type
- score#
The sequence classification head (linear layer).
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for matrix multiplications.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- class easydel.__init__.Qwen2MoeModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
The base Qwen2 MoE transformer model.
This class implements the core transformer architecture, including embedding layers, decoder layers, and final normalization.
- config#
Configuration object for the model.
- Type
- embed_tokens#
Embedding layer for input tokens.
- Type
nn.Embed
- layers#
List of decoder layers.
- Type
nn.List[Qwen2MoeDecoderLayer]
- gradient_checkpointing#
Gradient checkpointing strategy.
- Type
str
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for matrix multiplications.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- class easydel.__init__.Qwen2VLConfig(vocab_size=152064, hidden_size=8192, intermediate_size=29568, num_hidden_layers=80, num_attention_heads=64, num_key_value_heads=8, hidden_act='silu', max_position_embeddings=32768, initializer_range=0.02, rms_norm_eps=1e-05, use_cache=True, tie_word_embeddings=False, rope_theta=1000000.0, use_sliding_window=False, sliding_window=4096, max_window_layers=80, attention_dropout=0.0, vision_config=None, rope_scaling=None, vision_start_token_id=151652, vision_end_token_id=151653, vision_token_id=151654, image_token_id=151655, video_token_id=151656, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
This is the configuration class to store the configuration of a [Qwen2VLModel]. It is used to instantiate a Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
Configuration objects inherit from [PretrainedConfig] and can be used to control the model outputs. Read the documentation from [PretrainedConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 152064) – Vocabulary size of the Qwen2VL model. Defines the number of different tokens that can be represented by the input_ids passed when calling [Qwen2VLModel].
hidden_size (int, optional, defaults to 8192) – Dimension of the hidden representations.
intermediate_size (int, optional, defaults to 29568) – Dimension of the MLP representations.
num_hidden_layers (int, optional, defaults to 80) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 64) – Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (int, optional, defaults to 8) – The number of key/value heads used to implement Grouped Query Attention (GQA). If num_key_value_heads=num_attention_heads, the model uses Multi Head Attention (MHA); if num_key_value_heads=1, it uses Multi Query Attention (MQA); otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by mean-pooling all the original heads within that group. For more details, check out [this paper](https://arxiv.org/pdf/2305.13245.pdf). If set to None, it falls back to num_attention_heads.
hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) in the decoder.
max_position_embeddings (int, optional, defaults to 32768) – The maximum sequence length that this model might ever be used with.
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (float, optional, defaults to 1e-05) – The epsilon used by the rms normalization layers.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
tie_word_embeddings (bool, optional, defaults to False) – Whether the model’s input and output word embeddings should be tied.
rope_theta (float, optional, defaults to 1000000.0) – The base period of the RoPE embeddings.
use_sliding_window (bool, optional, defaults to False) – Whether to use sliding window attention.
sliding_window (int, optional, defaults to 4096) – Sliding window attention (SWA) window size. If not specified, will default to 4096.
max_window_layers (int, optional, defaults to 80) – The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
vision_config (tp.Dict, optional) – The config for the visual encoder initialization.
rope_scaling (tp.Dict, optional) –
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type and you expect the model to work on longer max_position_embeddings, we recommend you to update this value accordingly. Expected contents:
- rope_type (str):
The sub-variant of RoPE to use. Can be one of [‘default’, ‘linear’, ‘dynamic’, ‘yarn’, ‘longrope’, ‘llama3’], with ‘default’ being the original RoPE implementation.
- factor (float, optional):
Used with all rope types except ‘default’. The scaling factor to apply to the RoPE embeddings. In most scaling types, a factor of x will enable the model to handle sequences of length x * original maximum pre-trained length.
- original_max_position_embeddings (int, optional):
Used with ‘dynamic’, ‘longrope’ and ‘llama3’. The original max position embeddings used during pretraining.
- attention_factor (float, optional):
Used with ‘yarn’ and ‘longrope’. The scaling factor to be applied on the attention computation. If unspecified, it defaults to the value recommended by the implementation, using the factor field to infer the suggested value.
- beta_fast (float, optional):
Only used with ‘yarn’. Parameter to set the boundary for extrapolation (only) in the linear ramp function. If unspecified, it defaults to 32.
- beta_slow (float, optional):
Only used with ‘yarn’. Parameter to set the boundary for interpolation (only) in the linear ramp function. If unspecified, it defaults to 1.
- short_factor (tp.List[float], optional):
Only used with ‘longrope’. The scaling factor to be applied to short contexts (< original_max_position_embeddings). Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2.
- long_factor (tp.List[float], optional):
Only used with ‘longrope’. The scaling factor to be applied to long contexts (> original_max_position_embeddings). Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2.
- low_freq_factor (float, optional):
Only used with ‘llama3’. Scaling factor applied to low frequency components of the RoPE.
- high_freq_factor (float, optional):
Only used with ‘llama3’. Scaling factor applied to high frequency components of the RoPE.
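For the simplest non-default case, rope_type='linear', the factor divides every rotary frequency, which is equivalent to dividing positions by factor. The JAX sketch below covers only 'default' and 'linear'; rope_inv_freq and rope_angles are hypothetical helpers, and the other rope types need the extra fields listed above. Qwen2-VL itself uses a multimodal RoPE variant that this sketch does not model.

```python
import jax.numpy as jnp


def rope_inv_freq(head_dim: int, rope_theta: float, rope_scaling=None):
    """Inverse RoPE frequencies with optional 'linear' scaling (hypothetical helper)."""
    exponents = jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim
    inv_freq = 1.0 / (rope_theta ** exponents)
    if rope_scaling is not None:
        rope_type = rope_scaling.get("rope_type", "default")
        if rope_type == "linear":
            # Shrinking every frequency by `factor` stretches the usable context by `factor`.
            inv_freq = inv_freq / rope_scaling["factor"]
        elif rope_type != "default":
            raise NotImplementedError(f"rope_type {rope_type!r} is not covered by this sketch")
    return inv_freq


def rope_angles(positions, inv_freq):
    # One rotation angle per (position, frequency) pair.
    return jnp.outer(positions.astype(jnp.float32), inv_freq)


inv_default = rope_inv_freq(head_dim=128, rope_theta=1_000_000.0)
inv_linear = rope_inv_freq(128, 1_000_000.0, {"rope_type": "linear", "factor": 4.0})
```

With factor=4.0, position 8 receives the same rotation angles that position 2 had without scaling, which is how the model covers a context four times longer than the one it was trained on.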
```python
>>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig

>>> # Initializing a Qwen2VL style configuration
>>> configuration = Qwen2VLConfig()

>>> # Initializing a model from the Qwen2-VL-7B style configuration
>>> model = Qwen2VLForConditionalGeneration(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```
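The num_key_value_heads parameter above describes building GQA key and value heads by mean-pooling the original heads in each group. A minimal JAX sketch for one projection kernel follows; it assumes a (hidden_size, num_heads * head_dim) kernel layout with heads stored contiguously, and mha_to_gqa_kernel is a hypothetical helper rather than an EasyDeL or Transformers function.

```python
import jax.numpy as jnp


def mha_to_gqa_kernel(kernel, num_attention_heads: int, num_key_value_heads: int):
    """Mean-pool an MHA key/value projection kernel into GQA form.

    kernel: (hidden_size, num_attention_heads * head_dim), heads laid out contiguously.
    """
    if num_attention_heads % num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    hidden_size, total = kernel.shape
    head_dim = total // num_attention_heads
    group = num_attention_heads // num_key_value_heads
    # Consecutive heads form one group, matching how GQA shares each KV head.
    grouped = kernel.reshape(hidden_size, num_key_value_heads, group, head_dim)
    pooled = grouped.mean(axis=2)  # average the heads within each group
    return pooled.reshape(hidden_size, num_key_value_heads * head_dim)


# Toy kernel: hidden_size=2, 4 heads of size 3, pooled down to 2 KV heads.
kernel = jnp.arange(24, dtype=jnp.float32).reshape(2, 12)
gqa_kernel = mha_to_gqa_kernel(kernel, num_attention_heads=4, num_key_value_heads=2)
```

Setting num_key_value_heads equal to num_attention_heads leaves the kernel unchanged (MHA), and setting it to 1 averages all heads into a single shared head (MQA).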
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model parameters for use with distributed training.
These rules define how model parameters should be partitioned across multiple devices when using techniques like Fully Sharded Data Parallelism (FSDP), Sharded Parallelism (SP), and Tensor Parallelism (TP).
- Returns
- A tuple of tuples where each inner tuple contains:
A regex pattern matching parameter names
A PartitionSpec object specifying how to partition matching parameters
- Return type
tuple
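Applying rules of this shape amounts to taking, for each parameter name, the PartitionSpec of the first regex that matches. The patterns and mesh axis names below ("fsdp", "sp", "tp") are illustrative assumptions, not the rules EasyDeL actually returns, and match_partition_spec is a hypothetical helper.

```python
import re

from jax.sharding import PartitionSpec


def match_partition_spec(param_name: str, rules):
    """Return the PartitionSpec of the first rule whose regex matches param_name."""
    for pattern, spec in rules:
        if re.search(pattern, param_name):
            return spec
    # Replicate fully when nothing matches (a choice made for this sketch).
    return PartitionSpec()


# Hypothetical rules in the (regex, PartitionSpec) shape described above.
rules = (
    ("embed_tokens/embedding", PartitionSpec("tp", ("fsdp", "sp"))),
    ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
    ("self_attn/o_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),
    (".*norm.*", PartitionSpec(None)),
)
```

Because the first match wins, specific patterns should come before broad ones such as a catch-all ".*".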
- keys_to_ignore_at_inference = ['past_key_values']#
- model_type: str = 'qwen2_vl'#
- sub_configs: dict[str, 'PretrainedConfig'] = {'vision_config': <class 'easydel.__init__.modules.qwen2_vl.qwen2_vl_configuration.Qwen2VLVisionConfig'>}#
- class easydel.__init__.Qwen2VLForConditionalGeneration(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- get_static_arguments()[source]#
Returns a tuple of static arguments required by the module’s __call__ method.
Static arguments are those that don’t change across calls and can potentially be cached or handled differently by JIT compilation. This base implementation returns an empty tuple. Subclasses should override this if they have static arguments.
- Returns
A tuple containing static arguments.
- Return type
tp.Tuple
- loss_type = 'ForCausalLM'#
- prepare_inputs_for_call(image_grid_thw: Optional[Union[Array, ndarray, bool, number]] = None, video_grid_thw: Optional[Union[Array, ndarray, bool, number]] = None, image_max_grid_size: int = None, video_max_grid_size: int = None, drop_ids: bool = True, **others)[source]#
Prepares keyword arguments before passing them to the module’s __call__ method.
This base implementation simply returns the kwargs as is. Subclasses can override this to modify or add arguments as needed (e.g., for generation).
- Parameters
**kwargs – The keyword arguments intended for __call__.
- Returns
The prepared keyword arguments.
- Return type
dict
- prepare_inputs_for_generation(input_ids, max_length, past_key_values=None, attention_mask=None, inputs_embeds=None, position_ids=None, pixel_values=None, pixel_values_videos=None, image_grid_thw=None, video_grid_thw=None, **kwargs)[source]#
The prepare_inputs_for_generation function is used to prepare the inputs for a generation task.
- Parameters
input_ids – The input token IDs.
max_length – The maximum length of the sequence to be generated.
attention_mask (tp.Optional[chex.Array]) – Mask applied to the attention weights.
token_type_ids (tp.Optional[chex.Array]) – Token type IDs.
- Returns
A dictionary of the past_key_values, attention_mask and position ids
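The mask and position handling that prepare_inputs_for_generation performs can be sketched as follows, assuming the convention of Hugging Face's Flax generation code: the attention mask is padded to max_length with ones, and position ids are a running count of real tokens so left-padded prompts start at 0. build_generation_inputs is a hypothetical helper; it ignores past_key_values and the vision inputs, and Qwen2-VL's multimodal rotary position ids are not modeled here.

```python
import jax
import jax.numpy as jnp


def build_generation_inputs(input_ids, max_length: int, attention_mask=None):
    """Pad the attention mask to max_length and derive 1-D position ids."""
    batch_size, seq_length = input_ids.shape
    # Slots after the prompt are pre-filled with 1 (HF Flax convention); in that
    # design the cached attention separately masks slots not yet written.
    extended_mask = jnp.ones((batch_size, max_length), dtype=jnp.int32)
    if attention_mask is not None:
        attention_mask = attention_mask.astype(jnp.int32)
        # Running count of real tokens: left padding gets -1, first real token gets 0.
        position_ids = attention_mask.cumsum(axis=-1) - 1
        extended_mask = jax.lax.dynamic_update_slice(extended_mask, attention_mask, (0, 0))
    else:
        position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype=jnp.int32), (batch_size, seq_length))
    return {"attention_mask": extended_mask, "position_ids": position_ids}


# One left-padded prompt and one full prompt, generating up to 6 positions in total.
ids = jnp.array([[0, 0, 5, 6], [1, 2, 3, 4]])
mask = jnp.array([[0, 0, 1, 1], [1, 1, 1, 1]])
inputs = build_generation_inputs(ids, max_length=6, attention_mask=mask)
```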
- class easydel.__init__.Qwen2VLModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.Qwen3Config(vocab_size=151936, hidden_size=4096, intermediate_size=22016, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=32, head_dim=128, hidden_act='silu', max_position_embeddings=32768, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, attention_bias=False, use_sliding_window=False, sliding_window=4096, max_window_layers=28, attention_dropout=0.0, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model.
- Parameters
fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- model_type: str = 'qwen3'#
- class easydel.__init__.Qwen3ForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Qwen3 model with a Causal Language Modeling head.
This model consists of the base Qwen3 transformer (Qwen3Model) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction. Optionally, the input token embeddings can be tied to the output projection layer.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- model#
The core Qwen3 transformer model.
- Type
- lm_head#
The linear layer for projecting hidden states to vocabulary logits.
- Type
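The lm_head projection, and the optional tying to the input embeddings mentioned above, reduce to one matrix product. The JAX sketch below uses raw arrays instead of EasyDeL modules; compute_logits is hypothetical, and the (vocab_size, hidden_size) embedding layout and (hidden_size, vocab_size) head kernel layout are assumptions.

```python
import jax
import jax.numpy as jnp


def compute_logits(hidden_states, embedding, lm_head_kernel=None, tie_word_embeddings=False):
    """Project hidden states (batch, seq, hidden) to vocabulary logits (batch, seq, vocab)."""
    if tie_word_embeddings:
        # Reuse the (vocab, hidden) input embedding matrix, transposed, as the output projection.
        return hidden_states @ embedding.T
    # Untied: a separate (hidden, vocab) kernel.
    return hidden_states @ lm_head_kernel


vocab_size, hidden_size = 10, 4
emb = jax.random.normal(jax.random.PRNGKey(0), (vocab_size, hidden_size))
h = emb[jnp.array([[3, 7]])]  # embedding rows as stand-in hidden states, shape (1, 2, 4)
logits = compute_logits(h, emb, tie_word_embeddings=True)
# Greedy next-token choice from the last position's logits.
next_token = jnp.argmax(logits[:, -1, :], axis=-1)
```

Tying saves a vocab_size × hidden_size matrix of parameters, which matters for large vocabularies such as Qwen3's default of 151936 tokens.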
- class easydel.__init__.Qwen3ForSequenceClassification(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Qwen3 model with a Sequence Classification head.
This model consists of the base Qwen3 transformer (Qwen3Model) followed by a linear layer (score) that projects the transformer’s output hidden states (typically the hidden state of the last token or a pooled representation) to the number of classes for classification.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- model#
The core Qwen3 transformer model.
- Type
- score#
The linear layer for classification.
- Type
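Selecting the hidden state of the last token for classification is usually done by locating each sequence's last non-padding position. The JAX sketch below applies that to per-token scores; it assumes right padding with pad_token_id and no padding tokens inside sequences, and pool_last_token is a hypothetical helper rather than EasyDeL's implementation.

```python
import jax.numpy as jnp


def pool_last_token(scores, input_ids, pad_token_id: int):
    """Pick each sequence's scores at its last non-padding token (right padding assumed).

    scores: (batch, seq, num_labels); input_ids: (batch, seq).
    """
    # Index of the last real token = number of non-pad tokens minus one.
    last_idx = jnp.sum(input_ids != pad_token_id, axis=-1) - 1
    last_idx = jnp.maximum(last_idx, 0)  # guard against all-padding rows
    batch_idx = jnp.arange(scores.shape[0])
    return scores[batch_idx, last_idx]


ids = jnp.array([[5, 6, 7, 0], [8, 9, 0, 0]])
scores = jnp.arange(24, dtype=jnp.float32).reshape(2, 4, 3)
pooled = pool_last_token(scores, ids, pad_token_id=0)
```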
- class easydel.__init__.Qwen3Model(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
The base Qwen3 model transformer.
This class represents the core transformer architecture of the Qwen3 model, consisting of an embedding layer, multiple Qwen3DecoderLayer layers, and a final RMS normalization layer.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- embed_tokens#
Embedding layer for input tokens.
- Type
nn.Embed
- layers#
List of decoder layers.
- Type
tp.List[Qwen3DecoderLayer]
- gradient_checkpointing#
Gradient checkpointing configuration.
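The final RMS normalization layer mentioned above (with rms_norm_eps from Qwen3Config) rescales each hidden vector by the reciprocal of its root mean square and a learned per-feature weight. A minimal JAX sketch follows, with rms_norm as a hypothetical function; computing in float32 and casting back is a common choice, not a confirmed detail of EasyDeL's layer.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps: float = 1e-6):
    """RMSNorm over the last axis: x / sqrt(mean(x^2) + eps) * weight."""
    # Compute in float32 for numerical stability, then cast back to the input dtype.
    x32 = x.astype(jnp.float32)
    variance = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
    normed = x32 * jax.lax.rsqrt(variance + eps)
    return (normed * weight).astype(x.dtype)


x = jnp.array([[3.0, 4.0]])
out = rms_norm(x, jnp.ones(2), eps=1e-6)
```

Unlike LayerNorm, RMSNorm does not subtract the mean and has no bias, so it is cheaper while still keeping activation scale under control.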
- class easydel.__init__.Qwen3MoeConfig(vocab_size=151936, hidden_size=2048, intermediate_size=6144, num_hidden_layers=24, num_attention_heads=32, num_key_value_heads=4, hidden_act='silu', max_position_embeddings=32768, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, attention_bias=False, use_sliding_window=False, sliding_window=4096, max_window_layers=28, attention_dropout=0.0, decoder_sparse_step=1, moe_intermediate_size=768, num_experts_per_tok=8, num_experts=128, norm_topk_prob=False, output_router_logits=False, router_aux_loss_coef=0.001, mlp_only_layers=None, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model.
- Parameters
fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- model_type: str = 'qwen3_moe'#
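The decoder_sparse_step and mlp_only_layers arguments in the signature above decide which decoder layers get a sparse MoE block. The sketch below follows the rule used by the Hugging Face Qwen2-MoE and Qwen3-MoE references: a layer is sparse when it is not in mlp_only_layers, num_experts is positive, and (layer_idx + 1) is a multiple of decoder_sparse_step. That EasyDeL uses the same rule is an assumption, and moe_layer_indices is hypothetical.

```python
from typing import List, Optional


def moe_layer_indices(num_hidden_layers: int, decoder_sparse_step: int,
                      mlp_only_layers: Optional[List[int]] = None,
                      num_experts: int = 128) -> List[int]:
    """Indices of decoder layers that use a sparse MoE block (the rest use a dense MLP)."""
    dense_only = set(mlp_only_layers or [])
    return [
        i for i in range(num_hidden_layers)
        if i not in dense_only and num_experts > 0 and (i + 1) % decoder_sparse_step == 0
    ]


# With the defaults (decoder_sparse_step=1, mlp_only_layers=None) all 24 layers are sparse.
sparse_layers = moe_layer_indices(num_hidden_layers=24, decoder_sparse_step=1)
```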
- class easydel.__init__.Qwen3MoeForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Qwen3Moe model with a Causal Language Modeling head.
This model consists of the base Qwen3Moe transformer (Qwen3MoeModel) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction. Optionally, the input token embeddings can be tied to the output projection layer.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- model#
The core Qwen3Moe transformer model.
- Type
- lm_head#
The linear layer for projecting hidden states to vocabulary logits.
- Type
- class easydel.__init__.Qwen3MoeForSequenceClassification(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Qwen3Moe model with a Sequence Classification head.
This model consists of the base Qwen3Moe transformer (Qwen3MoeModel) followed by a linear layer (score) that projects the transformer’s output hidden states (typically the hidden state of the last token or a pooled representation) to the number of classes for classification.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- model#
The core Qwen3Moe transformer model.
- Type
- score#
The linear layer for classification.
- Type
- class easydel.__init__.Qwen3MoeModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
The base Qwen3Moe model transformer.
This class represents the core transformer architecture of the Qwen3Moe model, consisting of an embedding layer, multiple Qwen3MoeDecoderLayer layers, and a final RMS normalization layer.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- embed_tokens#
Embedding layer for input tokens.
- Type
nn.Embed
- layers#
List of decoder layers.
- Type
tp.List[Qwen3MoeDecoderLayer]
- gradient_checkpointing#
Gradient checkpointing configuration.
- class easydel.__init__.RewardConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: tp.Optional[str] = None, clip_grad: tp.Optional[float] = None, custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = None, dataloader_num_workers: tp.Optional[int] = 0, dataloader_pin_memory: tp.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: tp.Optional[int] = None, evaluation_steps: tp.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: tp.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: tp.Optional[dict] = None, learning_rate: float = 5e-05, learning_rate_end: tp.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: tp.Optional[LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: tp.Optional[int] = None, max_sequence_length: ~typing.Optional[int] = 1024, max_training_steps: tp.Optional[int] = None, model_name: str = 'RewardTrainer', model_parameters: tp.Optional[dict] = None, metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = False, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: tp.Optional[int] = None, save_total_limit: tp.Optional[int] = None, scheduler: AVAILABLE_SCHEDULERS = EasyDeLSchedulers.NONE, 
sparsify_module: bool = False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: tp.Optional[dict] = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: tp.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: tp.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: tp.Optional[jnp.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: tp.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0, disable_dropout: bool = True, dataset_num_proc: ~typing.Optional[int] = None, center_rewards_coefficient: ~typing.Optional[float] = 0.1)[source]#
Bases:
TrainingArguments
Configuration class for the [RewardTrainer].
- Parameters
model_name (str) – The name of the model. Defaults to “RewardTrainer”.
max_sequence_length (int, optional) – Maximum length of the sequences (prompt + completion) in the batch; entries that exceed the limit are filtered out. Defaults to 1024.
disable_dropout (bool, optional) – Whether to disable dropout in the model. Defaults to True.
dataset_num_proc (int, optional) – Number of processes to use for processing the dataset. Defaults to None.
center_rewards_coefficient (float, optional) – Coefficient to incentivize the reward model to output mean-zero rewards. Defaults to 0.1.
remove_unused_columns (bool, optional) – Whether to remove the columns that are not used by the model’s forward pass. Can be True only if the dataset is pretokenized. Defaults to False.
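The center_rewards_coefficient term can be made concrete with the pairwise loss used by TRL's RewardTrainer, which this configuration mirrors: a Bradley-Terry term that pushes chosen rewards above rejected ones, plus the coefficient times the mean squared sum of each pair's rewards. Whether EasyDeL computes exactly this loss is an assumption, and reward_loss is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def reward_loss(rewards_chosen, rewards_rejected, center_rewards_coefficient=0.1):
    """Pairwise (Bradley-Terry) reward loss with optional reward centering."""
    # Encourage each chosen reward to exceed its rejected counterpart.
    loss = -jnp.mean(jax.nn.log_sigmoid(rewards_chosen - rewards_rejected))
    if center_rewards_coefficient is not None:
        # Penalize (r_chosen + r_rejected)^2, pulling the reward scale toward mean zero.
        loss = loss + center_rewards_coefficient * jnp.mean((rewards_chosen + rewards_rejected) ** 2)
    return loss


centered = reward_loss(jnp.array([1.0]), jnp.array([-1.0]))
shifted = reward_loss(jnp.array([3.0]), jnp.array([1.0]))
```

Both pairs have the same reward gap, so the Bradley-Terry term is identical; only the shifted pair pays the centering penalty.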
- center_rewards_coefficient: Optional[float] = 0.1#
- dataset_num_proc: Optional[int] = None#
- disable_dropout: bool = True#
- classmethod from_dict(data)#
Create an instance from a dictionary (deserialization).
- classmethod from_json(json_str)#
Create an instance from a JSON string.
- max_sequence_length: Optional[int] = 1024#
- model_name: str = 'RewardTrainer'#
- remove_unused_columns: bool = False#
- replace(**kwargs)#
- to_dict()#
Convert the instance to a dictionary for JSON serialization.
- to_json(**kwargs)#
Convert the instance to a JSON string.
- class easydel.__init__.RewardTrainer(arguments: RewardConfig, processing_class: Any, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, data_collator: Optional[RewardDataCollatorWithPadding] = None)[source]#
Bases:
Trainer
Trainer for reward models. It extends Trainer with reward-specific step functions and batch collation.
- configure_functions() TrainerConfigureFunctionOutput[source]#
Configures and JIT-compiles the training and evaluation step functions.
This method prepares the functions that will be used during training and evaluation. It sets up sharding for the model parameters and optimizer state, JIT-compiles the training and evaluation functions with the appropriate static arguments and sharding constraints, and also sets up the checkpoint manager.
- Returns
- An object containing:
sharded_training_step_function: The compiled training step function.
sharded_evaluation_step_function: The compiled evaluation step function.
mesh: The device mesh used for computation.
checkpoint_manager: The checkpointer for saving/loading model state.
- Return type
TrainerConfigureFunctionOutput
- create_collect_function(max_sequence_length, truncation_mode='keep_end')[source]#
Creates a collate/collect function to process batches of data for training or evaluation.
This function returns a callable that takes a batch (a list of dictionaries) and converts it into a dictionary of JAX arrays. For models of class “ForCausalLMLoss”, it also performs truncation (either keeping the end or the start of the sequence) so that each sequence does not exceed the specified maximum length.
- Parameters
max_sequence_length (int) – The maximum allowed sequence length.
truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – Determines whether to keep the end or the start of the sequence when truncating. Defaults to “keep_end”.
- Returns
A function that takes a batch (list of dicts) and returns a processed dict of arrays.
- Return type
tp.Callable
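The keep_end / keep_start behavior described above can be sketched as a plain collate function. This version assumes every example in a batch already has the same length (for instance, after padding) so the arrays can be stacked; make_collate_fn and truncate are hypothetical names, not the function RewardTrainer returns.

```python
import numpy as np
import jax.numpy as jnp


def truncate(seq: np.ndarray, max_sequence_length: int, truncation_mode: str = "keep_end") -> np.ndarray:
    if max_sequence_length <= 0:
        raise ValueError("max_sequence_length must be positive")
    if len(seq) <= max_sequence_length:
        return seq
    if truncation_mode == "keep_end":
        return seq[-max_sequence_length:]  # drop tokens from the start
    if truncation_mode == "keep_start":
        return seq[:max_sequence_length]  # drop tokens from the end
    raise ValueError(f"unknown truncation_mode: {truncation_mode!r}")


def make_collate_fn(max_sequence_length: int, truncation_mode: str = "keep_end"):
    def collate(batch):
        # batch: list of dicts mapping field names to equal-length token lists.
        return {
            key: jnp.asarray(
                np.stack([truncate(np.asarray(ex[key]), max_sequence_length, truncation_mode) for ex in batch])
            )
            for key in batch[0]
        }
    return collate


batch = [
    {"input_ids": [1, 2, 3, 4, 5], "attention_mask": [1, 1, 1, 1, 1]},
    {"input_ids": [6, 7, 8, 9, 10], "attention_mask": [1, 1, 1, 1, 1]},
]
```

keep_end is usually preferred for chat-style data because the final turn, which the model is scored on, sits at the end of the sequence.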
- class easydel.__init__.RobertaConfig(vocab_size=50265, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act='gelu', hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=514, type_vocab_size=1, initializer_range=0.02, layer_norm_eps=1e-05, pad_token_id=1, bos_token_id=0, eos_token_id=2, position_embedding_type='absolute', use_cache=True, classifier_dropout=None, gradient_checkpointing='nothing_saveable', **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 50265) – Vocabulary size of the RoBERTa model. Defines the number of different tokens that can be represented by the input_ids passed when calling RobertaModel.
hidden_size (int, optional, defaults to 768) – Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (int, optional, defaults to 12) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 12) – Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (int, optional, defaults to 3072) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (str or function, optional, defaults to “gelu”) – The non-linear activation function (function or string) in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.
hidden_dropout_prob (float, optional, defaults to 0.1) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (float, optional, defaults to 0.1) – The dropout ratio for the attention probabilities.
max_position_embeddings (int, optional, defaults to 514) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
type_vocab_size (int, optional, defaults to 1) – The vocabulary size of the token_type_ids passed when calling RobertaModel.
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.
position_embedding_type (str, optional, defaults to “absolute”) – Type of position embedding. Choose one of “absolute”, “relative_key”, “relative_key_query”. For positional embeddings use “absolute”. For more information on “relative_key”, please refer to [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). For more information on “relative_key_query”, please refer to Method 4 in [Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
classifier_dropout (float, optional) – The dropout ratio for the classification head.
gradient_checkpointing (str, optional, defaults to “nothing_saveable”) – What to save during gradient checkpointing. Choose one of “nothing_saveable”, “first_half_saveable”, “full_saveable”.
- attach_custom_arguments(gradient_checkpointing='nothing_saveable', **kwargs)[source]#
Attach custom arguments to the configuration.
This method allows attaching additional custom arguments to the configuration that weren’t part of the initial configuration.
- Parameters
gradient_checkpointing (str, optional, defaults to “nothing_saveable”) – What to save during gradient checkpointing. Choose one of “nothing_saveable”, “first_half_saveable”, “full_saveable”.
**kwargs – Additional custom arguments to be attached to the configuration.
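The “nothing_saveable” option shares its name with a JAX rematerialization policy, jax.checkpoint_policies.nothing_saveable, which stores no intermediate activations and recomputes them during the backward pass. Assuming EasyDeL maps the string to that policy (not confirmed here; “first_half_saveable” has no JAX policy of the same name), the effect on a single block looks like this. Checkpointing trades compute for memory and leaves the gradients unchanged.

```python
import jax
import jax.numpy as jnp


def mlp_block(w1, w2, x):
    # A two-layer block whose hidden activation would normally be stored for backprop.
    return jnp.tanh(x @ w1) @ w2


# Recompute the block's intermediates in the backward pass instead of storing them.
remat_block = jax.checkpoint(mlp_block, policy=jax.checkpoint_policies.nothing_saveable)

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
w1 = jax.random.normal(k1, (4, 8))
w2 = jax.random.normal(k2, (8, 2))
x = jax.random.normal(k3, (3, 4))

# Gradients with respect to both weight matrices, with and without checkpointing.
grads_plain = jax.grad(lambda a, b: jnp.sum(mlp_block(a, b, x) ** 2), argnums=(0, 1))(w1, w2)
grads_remat = jax.grad(lambda a, b: jnp.sum(remat_block(a, b, x) ** 2), argnums=(0, 1))(w1, w2)
```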
- get_partition_rules(fully_sharded_data_parallel: bool = True)[source]#
Get the partition rules for the model.
- Parameters
fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- model_type: str = 'roberta'#
- class easydel.__init__.RobertaForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.RobertaForMultipleChoice(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.RobertaForQuestionAnswering(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.RobertaForSequenceClassification(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.RobertaForTokenClassification(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.SFTConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: tp.Optional[str] = None, clip_grad: tp.Optional[float] = None, custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = None, dataloader_num_workers: tp.Optional[int] = 0, dataloader_pin_memory: tp.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: tp.Optional[int] = None, evaluation_steps: tp.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: tp.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: tp.Optional[dict] = None, learning_rate: float = 2e-05, learning_rate_end: tp.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: tp.Optional[LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: tp.Optional[int] = None, max_sequence_length: tp.Optional[int] = 4096, max_training_steps: tp.Optional[int] = None, model_name: str = 'SFTTrainer', model_parameters: tp.Optional[dict] = None, metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: tp.Optional[int] = None, save_total_limit: tp.Optional[int] = None, scheduler: AVAILABLE_SCHEDULERS = EasyDeLSchedulers.NONE, sparsify_module: bool 
= False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: tp.Optional[dict] = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: tp.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: tp.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: tp.Optional[jnp.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: tp.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0, dataset_text_field: ~typing.Optional[str] = None, add_special_tokens: bool = False, packing: bool = False, dataset_num_proc: ~typing.Optional[int] = None, dataset_batch_size: int = 1000, dataset_kwargs: ~typing.Optional[dict[str, typing.Any]] = None, eval_packing: ~typing.Optional[bool] = None, num_of_sequences: int = 1024, chars_per_token: float = 3.6)[source]#
Bases:
TrainingArguments
Configuration class for the [SFTTrainer].
- Parameters
model_name (str) – The name of the model. Defaults to “SFTTrainer”.
dataset_text_field (str, optional) – Name of the text field of the dataset. If provided, the trainer will automatically create a [ConstantLengthDataset] based on dataset_text_field. Defaults to None.
packing (bool, optional) – Controls whether the [ConstantLengthDataset] packs the sequences of the dataset. Defaults to False.
learning_rate (float, optional) – Initial learning rate for the optimizer ([AdamW] by default). This default overrides the 5e-05 default of [TrainingArguments]. Defaults to 2e-5.
dataset_num_proc (int, optional) – Number of processes to use for processing the dataset. Only used when packing=False. Defaults to None.
dataset_batch_size (int, optional) – Number of examples to tokenize per batch. If dataset_batch_size <= 0 or dataset_batch_size is None, tokenizes the full dataset as a single batch. Defaults to 1000.
dataset_kwargs (dict[str, Any], optional) – Dictionary of optional keyword arguments to pass when creating packed or non-packed datasets. Defaults to None.
eval_packing (bool, optional) – Whether to pack the eval dataset. If None, uses the same value as packing. Defaults to None.
num_of_sequences (int, optional) – Number of sequences to use for the [ConstantLengthDataset]. Defaults to 1024.
chars_per_token (float, optional) – Number of characters per token to use for the [ConstantLengthDataset]. See [chars_token_ratio](huggingface/trl) for more details. Defaults to 3.6.
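The interaction between packing, num_of_sequences, and chars_per_token can be illustrated with a minimal plain-Python sketch of constant-length packing. It follows the general idea behind [ConstantLengthDataset]: fill a character buffer sized roughly `seq_length * chars_per_token * num_of_sequences`, tokenize, join examples with an EOS token, and cut the stream into equal blocks. The helper `pack_constant_length`, the toy tokenizer, and dropping the trailing partial block are assumptions for illustration, not EasyDeL's implementation.

```python
from typing import Callable, Iterable, Iterator, List


def pack_constant_length(
    texts: Iterable[str],
    tokenize: Callable[[str], List[int]],  # hypothetical tokenizer: text -> token ids
    seq_length: int,
    eos_id: int,
    num_of_sequences: int = 1024,
    chars_per_token: float = 3.6,
) -> Iterator[List[int]]:
    """Yield fixed-length token blocks built from concatenated texts (sketch)."""
    # Character budget per buffer: enough raw text to fill roughly
    # num_of_sequences blocks of seq_length tokens.
    max_buffer_chars = int(seq_length * chars_per_token * num_of_sequences)
    iterator = iter(texts)
    exhausted = False
    while not exhausted:
        buffer: List[str] = []
        buffer_chars = 0
        while buffer_chars < max_buffer_chars:
            try:
                text = next(iterator)
            except StopIteration:
                exhausted = True
                break
            buffer.append(text)
            buffer_chars += len(text)
        # Concatenate all tokenized texts, each followed by an EOS token.
        all_ids: List[int] = []
        for text in buffer:
            all_ids.extend(tokenize(text) + [eos_id])
        # Emit only full blocks; a trailing partial block is dropped.
        for start in range(0, len(all_ids) - seq_length + 1, seq_length):
            yield all_ids[start : start + seq_length]


def toy_tokenize(text: str) -> List[int]:
    # Character-level stand-in for a real tokenizer.
    return [ord(c) for c in text]


blocks = list(pack_constant_length(["abc", "defg", "hi"], toy_tokenize, seq_length=4, eos_id=0))
```

Packing removes padding waste at the cost of letting one block span several examples, which is why an EOS separator is inserted between them.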
- add_special_tokens: bool = False#
- chars_per_token: float = 3.6#
- dataset_batch_size: int = 1000#
- dataset_kwargs: Optional[dict[str, Any]] = None#
- dataset_num_proc: Optional[int] = None#
- dataset_text_field: Optional[str] = None#
- eval_packing: Optional[bool] = None#
- classmethod from_dict(data)#
Create an instance from a dictionary (deserialization).
- classmethod from_json(json_str)#
Create an instance from a JSON string.
- learning_rate: float = 2e-05#
- model_name: str = 'SFTTrainer'#
- num_of_sequences: int = 1024#
- packing: bool = False#
- replace(**kwargs)#
- to_dict()#
Convert the instance to a dictionary for JSON serialization.
- to_json(**kwargs)#
Convert the instance to a JSON string.
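The to_dict/from_dict, to_json/from_json, and replace helpers suggest a dataclass-style round trip. The sketch below uses a hypothetical stand-in dataclass, `ToySFTConfig`, with only a few of the fields above, and assumes replace behaves like `dataclasses.replace`; the reference does not state this.

```python
import dataclasses
import json
from typing import Any, Dict, Optional


@dataclasses.dataclass
class ToySFTConfig:
    # Hypothetical stand-in carrying a few SFTConfig fields and their documented defaults.
    model_name: str = "SFTTrainer"
    learning_rate: float = 2e-05
    packing: bool = False
    dataset_text_field: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ToySFTConfig":
        return cls(**data)

    def to_json(self, **kwargs: Any) -> str:
        # kwargs are forwarded to json.dumps (e.g. indent=2).
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str: str) -> "ToySFTConfig":
        return cls.from_dict(json.loads(json_str))

    def replace(self, **kwargs: Any) -> "ToySFTConfig":
        # Returns a new instance; the original is left unchanged.
        return dataclasses.replace(self, **kwargs)


cfg = ToySFTConfig(dataset_text_field="text")
packed = cfg.replace(packing=True)
restored = ToySFTConfig.from_json(packed.to_json(indent=2))
```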
- class easydel.__init__.SFTTrainer(arguments: SFTConfig, processing_class: Any, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, formatting_func: Optional[Callable] = None, data_collator: Optional[DataCollatorForCompletionOnlyLM] = None)[source]#
Bases:
Trainer
Trainer class for Supervised Fine-Tuning (SFT) of language models.
This trainer extends [Trainer] with functionality specific to supervised fine-tuning.
- class easydel.__init__.SamplingParams(max_tokens: int = 16, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repetition_penalty: float = 1.0, temperature: float = 0.0, top_p: float = 1.0, top_k: int = 0, min_p: float = 0.0, suppress_tokens: list[int] = <factory>)[source]#
Bases:
object
Parameters controlling the sampling process during text generation.
- max_tokens#
The maximum number of tokens to generate (excluding the prompt). Defaults to 16.
- Type
int
- presence_penalty#
Penalty applied to the logits of tokens already present in the generated sequence. Positive values discourage repetition. Defaults to 0.0.
- Type
float
- frequency_penalty#
Penalty applied to the logits of tokens based on their frequency in the generated sequence so far. Positive values discourage verbatim repetition. Defaults to 0.0.
- Type
float
- repetition_penalty#
Multiplicative penalty applied to the logits of previously seen tokens. Values > 1.0 discourage repetition, < 1.0 encourage it. Defaults to 1.0.
- Type
float
- temperature#
Controls the randomness of sampling. Higher values (e.g., > 1.0) flatten the distribution (more random); lower values (e.g., < 1.0) sharpen it (more deterministic). A value of 0.0 corresponds to greedy decoding. Defaults to 0.0.
- Type
float
- top_p#
Nucleus sampling threshold. If set to a value < 1.0, sampling is restricted to the smallest set of most probable tokens whose cumulative probability is at least top_p. Defaults to 1.0 (no nucleus sampling).
- Type
float
- top_k#
Top-k sampling threshold. If set to a value > 0, only the top_k most probable tokens are considered for sampling. Defaults to 0 (no top-k sampling).
- Type
int
- min_p#
Minimum probability threshold. Filters out tokens with probability less than min_p. Defaults to 0.0 (no minimum probability filtering).
- Type
float
- suppress_tokens#
A list of token IDs that should be completely suppressed (their logits set to -inf) during generation. Defaults to an empty list.
- Type
list[int]
- frequency_penalty: float = 0.0#
- classmethod from_dict(data)#
Create an instance from a dictionary (deserialization).
- classmethod from_json(json_str)#
Create an instance from a JSON string.
- get_logits_processor()[source]#
Constructs a LogitsProcessorList containing the configured logits processors.
Logits processors modify the logits directly, often used for applying penalties (presence, frequency, repetition) or suppressing specific tokens.
- Returns
A LogitsProcessorList containing the enabled logits processors based on the sampling parameters.
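The penalties handled by these processors can be sketched on a single logits vector in plain Python. The formulations below are the widely used ones (additive presence and frequency penalties, a CTRL-style multiplicative repetition penalty, and suppression via -inf); EasyDeL's processors may differ in detail or ordering, and `apply_penalties` is a hypothetical helper.

```python
import math
from collections import Counter
from typing import List, Sequence


def apply_penalties(
    logits: Sequence[float],
    generated_ids: Sequence[int],
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    repetition_penalty: float = 1.0,
    suppress_tokens: Sequence[int] = (),
) -> List[float]:
    out = list(logits)
    counts = Counter(generated_ids)
    for token_id, count in counts.items():
        # Repetition penalty (multiplicative): shrink positive logits and push
        # negative logits further down, so values > 1.0 discourage reuse.
        if repetition_penalty != 1.0:
            if out[token_id] > 0:
                out[token_id] /= repetition_penalty
            else:
                out[token_id] *= repetition_penalty
        # Presence penalty: flat subtraction for any token seen at least once.
        out[token_id] -= presence_penalty
        # Frequency penalty: subtraction proportional to how often it was seen.
        out[token_id] -= frequency_penalty * count
    # Suppressed tokens can never be sampled.
    for token_id in suppress_tokens:
        out[token_id] = -math.inf
    return out


penalized = apply_penalties(
    [2.0, 1.0, -1.0, 0.5],
    generated_ids=[0, 0, 2],
    presence_penalty=0.5,
    frequency_penalty=0.25,
    repetition_penalty=2.0,
    suppress_tokens=[3],
)
```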
- get_logits_warper()[source]#
Constructs a LogitsProcessorList containing the configured logits warpers.
Logits warpers modify the probability distribution derived from logits, typically used for techniques like temperature scaling, top-k, top-p, and min-p sampling.
- Returns
A LogitsProcessorList containing the enabled logits warpers based on the sampling parameters.
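The warpers can likewise be sketched in plain Python. This simplified version computes each filter (top-k, top-p, min-p) on the temperature-scaled distribution and keeps the intersection of the surviving tokens; chained implementations instead renormalize between steps. The min-p rule uses the common relative threshold (min_p times the top token's probability), whereas the min_p attribute above describes an absolute threshold, so treat that detail as an assumption. `warp` and `softmax` are hypothetical helpers.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    # Numerically stable softmax that maps -inf logits to probability 0.
    m = max(x for x in logits if x != -math.inf)
    exps = [math.exp(x - m) if x != -math.inf else 0.0 for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def warp(
    logits: Sequence[float],
    temperature: float = 0.0,
    top_k: int = 0,
    top_p: float = 1.0,
    min_p: float = 0.0,
) -> List[float]:
    """Return a filtered probability distribution over token ids (sketch)."""
    if temperature == 0.0:
        # Greedy: all probability mass on the highest-scoring token.
        best = max(range(len(logits)), key=lambda i: logits[i])
        return [1.0 if i == best else 0.0 for i in range(len(logits))]
    scaled = [x / temperature for x in logits]
    probs = softmax(scaled)
    order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    keep = set(order)
    if top_k > 0:
        keep &= set(order[:top_k])
    if top_p < 1.0:
        # Smallest prefix of the sorted tokens whose mass reaches top_p.
        cumulative, nucleus = 0.0, set()
        for i in order:
            nucleus.add(i)
            cumulative += probs[i]
            if cumulative >= top_p:
                break
        keep &= nucleus
    if min_p > 0.0:
        threshold = min_p * probs[order[0]]
        keep &= {i for i in order if probs[i] >= threshold}
    masked = [x if i in keep else -math.inf for i, x in enumerate(scaled)]
    return softmax(masked)


logits = [2.0, 1.0, 0.0, -1.0]
greedy = warp(logits, temperature=0.0)
top2 = warp(logits, temperature=1.0, top_k=2)
nucleus = warp(logits, temperature=1.0, top_p=0.5)
minp = warp(logits, temperature=1.0, min_p=0.2)
```

The most probable token always survives every filter, so the final softmax never receives an all -inf vector.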
- max_tokens: int = 16#
- min_p: float = 0.0#
- presence_penalty: float = 0.0#
- repetition_penalty: float = 1.0#
- replace(**kwargs)#
- suppress_tokens: list[int]#
- temperature: float = 0.0#
- to_dict()#
Convert the instance to a dictionary for JSON serialization.
- to_json(**kwargs)#
Convert the instance to a JSON string.
- top_k: int = 0#
- top_p: float = 1.0#
- class easydel.__init__.SiglipConfig(text_config=None, vision_config=None, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
[SiglipConfig] is the configuration class that stores the configuration of a [SiglipModel]. It is used to instantiate a SigLIP model according to the specified arguments, defining the text and vision model configs. Instantiating a configuration with the defaults yields a configuration similar to that of the SigLIP [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
text_config (dict, optional) – Dictionary of configuration options used to initialize [SiglipTextConfig].
vision_config (dict, optional) – Dictionary of configuration options used to initialize [SiglipVisionConfig].
kwargs (optional) – Dictionary of keyword arguments.
- classmethod from_text_vision_configs(text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs)[source]#
Instantiate a [SiglipConfig] (or a derived class) from a SigLIP text model configuration and a SigLIP vision model configuration.
- Returns
An instance of a configuration object
- Return type
[SiglipConfig]
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the SigLIP model parameters for use with distributed training.
These rules define how parameters are partitioned across devices when using techniques such as Fully Sharded Data Parallelism (FSDP), Sequence Parallelism (SP), and Tensor Parallelism (TP). Each rule consists of a regex pattern for matching parameter names and a corresponding PartitionSpec.
- Returns
- A tuple of tuples where each inner tuple contains:
A regex pattern matching parameter names
A PartitionSpec object specifying how to partition matching parameters
- Return type
tuple
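A minimal sketch of how such a tuple of (regex, PartitionSpec) rules is typically consumed: each parameter path is checked against the rules in order and the first match wins. The rule set, the parameter paths, and the use of plain tuples in place of `jax.sharding.PartitionSpec` are hypothetical, chosen so the example runs without JAX; the actual SigLIP rules differ.

```python
import re
from typing import Sequence, Tuple, Union

# Stand-in for jax.sharding.PartitionSpec: one entry per array dimension naming
# the mesh axis (or tuple of axes) it is sharded over; () means fully replicated.
Spec = Tuple[Union[None, str, Tuple[str, ...]], ...]

# Hypothetical rules; order matters because the first matching pattern wins.
RULES: Sequence[Tuple[str, Spec]] = (
    (r"embeddings/.*/embedding", ("tp", ("fsdp", "sp"))),
    (r"self_attn/(q_proj|k_proj|v_proj)/kernel", (("fsdp", "sp"), "tp")),
    (r"self_attn/out_proj/kernel", ("tp", ("fsdp", "sp"))),
    (r"bias$", ()),
    (r".*", ()),  # catch-all: replicate everything else
)


def match_partition_spec(param_path: str, rules: Sequence[Tuple[str, Spec]] = RULES) -> Spec:
    """Return the spec of the first rule whose regex is found in param_path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


q_spec = match_partition_spec("vision_model/encoder/layers/0/self_attn/q_proj/kernel")
bias_spec = match_partition_spec("text_model/encoder/layers/0/self_attn/q_proj/bias")
norm_spec = match_partition_spec("vision_model/post_layernorm/scale")
```

Ending with a catch-all rule guarantees every parameter receives a spec, so a new parameter name falls back to replication instead of raising.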
- model_type: str = 'siglip'#
The model type identifier used to determine which model configuration this represents. This is set to “siglip” to identify this as the main configuration for the SigLIP model.
- sub_configs: dict[str, 'PretrainedConfig'] = {'text_config': <class 'easydel.__init__.modules.siglip.configuration_siglip.SiglipTextConfig'>, 'vision_config': <class 'easydel.__init__.modules.siglip.configuration_siglip.SiglipVisionConfig'>}#
A dictionary that maps configuration keys to their respective configuration classes. This enables the SiglipConfig to manage both text and vision components through separate configurations while maintaining them as part of a single unified model.
- class easydel.__init__.SiglipForImageClassification(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.SiglipModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- get_image_features(pixel_values: Optional[Union[Array, ndarray, bool, number]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, interpolate_pos_encoding: bool = False) Union[Array, ndarray, bool, number][source]#
- get_text_features(input_ids: Optional[Union[Array, ndarray, bool, number]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None, position_ids: Optional[Union[Array, ndarray, bool, number]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None) Union[Array, ndarray, bool, number][source]#
- class easydel.__init__.SiglipTextConfig(vocab_size=32000, hidden_size=768, intermediate_size=3072, num_hidden_layers=12, num_attention_heads=12, max_position_embeddings=64, hidden_act='gelu_pytorch_tanh', layer_norm_eps=1e-06, attention_dropout=0.0, pad_token_id=1, bos_token_id=49406, eos_token_id=49407, projection_size=None, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 32000) – Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by the input_ids passed when calling [SiglipModel].
hidden_size (int, optional, defaults to 768) – Dimensionality of the encoder layers and the pooler layer.
intermediate_size (int, optional, defaults to 3072) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (int, optional, defaults to 12) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 12) – Number of attention heads for each attention layer in the Transformer encoder.
max_position_embeddings (int, optional, defaults to 64) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
hidden_act (str or function, optional, defaults to “gelu_pytorch_tanh”) – The non-linear activation function (function or string) in the encoder and pooler. If string, “gelu”, “relu”, “selu”, “gelu_new”, and “quick_gelu” are supported.
layer_norm_eps (float, optional, defaults to 1e-06) – The epsilon used by the layer normalization layers.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
pad_token_id (int, optional, defaults to 1) – The id of the padding token in the vocabulary.
bos_token_id (int, optional, defaults to 49406) – The id of the beginning-of-sequence token in the vocabulary.
eos_token_id (int, optional, defaults to 49407) – The id of the end-of-sequence token in the vocabulary.
projection_size (int, optional, defaults to hidden_size) – The size of the projection head.
Example:
```python
>>> from transformers import SiglipTextConfig, SiglipTextModel

>>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration
>>> configuration = SiglipTextConfig()

>>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration
>>> model = SiglipTextModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```
- base_config_key: str = 'text_config'#
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- model_type: str = 'siglip_text_model'#
- class easydel.__init__.SiglipTextModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.SiglipVisionConfig(hidden_size=768, intermediate_size=3072, num_hidden_layers=12, num_attention_heads=12, num_channels=3, image_size=224, patch_size=16, hidden_act='gelu_pytorch_tanh', layer_norm_eps=1e-06, attention_dropout=0.0, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
hidden_size (int, optional, defaults to 768) – Dimensionality of the encoder layers and the pooler layer.
intermediate_size (int, optional, defaults to 3072) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (int, optional, defaults to 12) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 12) – Number of attention heads for each attention layer in the Transformer encoder.
num_channels (int, optional, defaults to 3) – Number of channels in the input images.
image_size (int, optional, defaults to 224) – The size (resolution) of each image.
patch_size (int, optional, defaults to 16) – The size (resolution) of each patch.
hidden_act (str or function, optional, defaults to “gelu_pytorch_tanh”) – The non-linear activation function (function or string) in the encoder and pooler. If string, “gelu”, “relu”, “selu”, “gelu_new”, and “quick_gelu” are supported.
layer_norm_eps (float, optional, defaults to 1e-06) – The epsilon used by the layer normalization layers.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
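With the defaults above, the vision encoder splits each 224×224 image into non-overlapping 16×16 patches. The arithmetic below assumes the standard ViT-style patch embedding in which each patch becomes one token and no class token is prepended (as in the reference SigLIP design); `siglip_vision_patch_stats` is a hypothetical helper.

```python
from typing import Dict


def siglip_vision_patch_stats(
    image_size: int = 224,
    patch_size: int = 16,
    num_channels: int = 3,
    hidden_size: int = 768,
    num_attention_heads: int = 12,
) -> Dict[str, int]:
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size  # 224 // 16 = 14
    return {
        "num_patches": patches_per_side ** 2,  # 14 * 14 = 196 patch tokens per image
        "values_per_patch": num_channels * patch_size ** 2,  # 3 * 16 * 16 = 768 pixel values
        "head_dim": hidden_size // num_attention_heads,  # 768 // 12 = 64
    }


stats = siglip_vision_patch_stats()
```

Because the token count grows with the square of image_size / patch_size, raising image_size to 384 at the same patch size gives 24 × 24 = 576 tokens per image.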
- base_config_key: str = 'vision_config'#
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- model_type: str = 'siglip_vision_model'#
- class easydel.__init__.StableLmConfig(vocab_size=50304, intermediate_size=6912, hidden_size=2560, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=32, hidden_act='silu', max_position_embeddings=4096, initializer_range=0.02, layer_norm_eps=1e-05, use_cache=True, tie_word_embeddings=False, rope_theta=10000, rope_scaling=None, use_qkv_bias=False, qk_layernorm=False, use_parallel_residual=False, hidden_dropout=0.0, attention_dropout=0.0, partial_rotary_factor=0.25, bos_token_id=0, eos_token_id=0, bits: Optional[int] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 50304) – Vocabulary size of the StableLM model. Defines the number of different tokens that can be represented by the input_ids passed when calling [~easydel.modules.StableLmModel].
hidden_size (int, optional, defaults to 2560) – Dimensionality of the encoder layers and the pooler layer.
intermediate_size (int, optional, defaults to 6912) – Dimensionality of the “intermediate” (often named feed-forward) layer in the Transformer encoder.
num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (int, optional, defaults to 32) – Number of key-value heads for each attention layer in the Transformer encoder.
hidden_act (str or tp.Callable, optional, defaults to “silu”) – The non-linear activation function (function or string) in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.
max_position_embeddings (int, optional, defaults to 4096) – The maximum sequence length that this model might ever be used with.
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models).
tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.
rope_theta (int, optional, defaults to 10000) – The theta value for the rotary position embeddings.
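rope_scaling (dict, optional) – Configuration for scaling the rotary position embeddings.
use_qkv_bias (bool, optional, defaults to False) – Whether to use bias terms in the query, key, and value projections.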
qk_layernorm (bool, optional, defaults to False) – Whether to use layer normalization on the queries and keys in the attention layer.
use_parallel_residual (bool, optional, defaults to False) – Whether to use a parallel residual connection in the attention layer.
hidden_dropout (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
partial_rotary_factor (float, optional, defaults to 0.25) – Fraction of each query and key head dimension to which rotary position embeddings are applied.
bos_token_id (int, optional, defaults to 0) – The id for the beginning of stream token.
eos_token_id (int, optional, defaults to 0) – The id for the end of stream token.
bits (int, optional) – The number of bits to quantize the model to. If None, the model is not quantized.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing policy, which controls what is saved during gradient checkpointing.
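For the default StableLM configuration, the per-head dimension and the number of rotated dimensions work out as below. This assumes the convention used by Hugging Face's StableLM, where the rotary dimension is `int(partial_rotary_factor * head_dim)` and the remaining dimensions pass through unrotated; `stablelm_attention_dims` is a hypothetical helper.

```python
from typing import Tuple


def stablelm_attention_dims(
    hidden_size: int = 2560,
    num_attention_heads: int = 32,
    num_key_value_heads: int = 32,
    partial_rotary_factor: float = 0.25,
) -> Tuple[int, int, int, int]:
    head_dim = hidden_size // num_attention_heads  # 2560 // 32 = 80
    rotary_dims = int(partial_rotary_factor * head_dim)  # int(0.25 * 80) = 20
    pass_through_dims = head_dim - rotary_dims  # 60 dims receive no rotation
    # Query heads per key/value head; 1 means plain multi-head attention.
    kv_groups = num_attention_heads // num_key_value_heads
    return head_dim, rotary_dims, pass_through_dims, kv_groups


default_dims = stablelm_attention_dims()
```

Setting num_key_value_heads below num_attention_heads (e.g. 8) would turn this into grouped-query attention with 4 query heads sharing each key/value head.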
- get_partition_rules(fully_sharded_data_parallel: bool = True)[source]#
Get the partition rules for the model.
- Parameters
fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- property granted_freq_max_position_embedding: int#
Returns the maximum position embedding size specifically for frequency-based position embeddings.
If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.
- Returns
The granted maximum position embedding size for frequency encoding.
- Return type
int
- property granted_mask_max_position_embedding: int#
Returns the maximum position embedding size specifically for mask-based position embeddings.
If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.
- Returns
The granted maximum position embedding size for mask encoding.
- Return type
int
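Both granted_* properties apply the same fallback rule. A minimal sketch with a hypothetical stand-in class, assuming "not set" means None:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyPositionConfig:
    # Hypothetical stand-in; only the fields the two properties read.
    max_position_embeddings: int = 4096
    freq_max_position_embeddings: Optional[int] = None
    mask_max_position_embeddings: Optional[int] = None

    @property
    def granted_freq_max_position_embedding(self) -> int:
        # Prefer the frequency-specific size, else fall back to the global maximum.
        if self.freq_max_position_embeddings is not None:
            return self.freq_max_position_embeddings
        return self.max_position_embeddings

    @property
    def granted_mask_max_position_embedding(self) -> int:
        # Prefer the mask-specific size, else fall back to the global maximum.
        if self.mask_max_position_embeddings is not None:
            return self.mask_max_position_embeddings
        return self.max_position_embeddings


cfg = ToyPositionConfig(freq_max_position_embeddings=8192)
```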
- model_type: str = 'stablelm'#
- class easydel.__init__.StableLmForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
StableLM model with a Causal Language Modeling (CLM) head.
This class wraps the base StableLmModel and adds a linear layer (language model head) to predict the next token logits.
- config#
Configuration object for the model.
- Type
StableLmConfig
- model#
The base StableLM model.
- Type
StableLmModel
- lm_head#
The language model head (linear layer).
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for matrix multiplications.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- class easydel.__init__.StableLmModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
The base StableLM transformer model.
This class implements the core transformer architecture, including embedding layers, decoder layers, and final normalization.
- config#
Configuration object for the model.
- Type
StableLmConfig
- embed_tokens#
Embedding layer for input tokens.
- Type
nn.Embed
- layers#
List of decoder layers.
- Type
nn.List[StableLmDecoderLayer]
- norm#
Final layer normalization.
- Type
nn.LayerNorm
- gradient_checkpointing#
Gradient checkpointing strategy.
- Type
str
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for matrix multiplications.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- property frequencies#
Cached property for precomputed rotary frequencies.
- class easydel.__init__.TaskType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#
Bases:
str, Enum
Enumeration defining different model task types supported by the registry.
- CAUSAL_LM#
Causal Language Modeling (e.g., GPT-style models).
- VISION_LM#
Vision Language Modeling (models combining vision and text).
- IMAGE_TEXT_TO_TEXT#
Models that take image and text input to produce text output.
- BASE_MODULE#
Basic, potentially abstract, modules.
- BASE_VISION#
Basic vision modules.
- SEQUENCE_TO_SEQUENCE#
Sequence-to-sequence tasks (e.g., translation, summarization).
- SPEECH_SEQUENCE_TO_SEQUENCE#
Speech-to-text or other speech sequence tasks.
- ZERO_SHOT_IMAGE_CLASSIFICATION#
Image classification without task-specific training.
- SEQUENCE_CLASSIFICATION#
Classifying entire sequences (e.g., sentiment analysis).
- AUDIO_CLASSIFICATION#
Classifying audio data.
- IMAGE_CLASSIFICATION#
Classifying images.
- AUDIO_CLASSIFICATION = 'audio-classification'#
- BASE_MODULE = 'base-module'#
- BASE_VISION = 'vision-module'#
- CAUSAL_LM = 'causal-language-model'#
- IMAGE_CLASSIFICATION = 'image-classification'#
- IMAGE_TEXT_TO_TEXT = 'image-text-to-text'#
- SEQUENCE_CLASSIFICATION = 'sequence-classification'#
- SEQUENCE_TO_SEQUENCE = 'sequence-to-sequence'#
- SPEECH_SEQUENCE_TO_SEQUENCE = 'speech-sequence-to-sequence'#
- VISION_LM = 'vision-language-model'#
- ZERO_SHOT_IMAGE_CLASSIFICATION = 'zero-shot-image-classification'#
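Because TaskType mixes in str, its members compare equal to their string values and can be looked up by value. The sketch below re-declares three of the members listed above with the standard enum module; it is a stand-in for illustration, not an import of EasyDeL.

```python
from enum import Enum


class TaskType(str, Enum):
    # Subset of the documented members, with the same string values.
    CAUSAL_LM = "causal-language-model"
    SEQUENCE_CLASSIFICATION = "sequence-classification"
    IMAGE_TEXT_TO_TEXT = "image-text-to-text"


# Look up a member from a plain string, e.g. one read from a config file.
task = TaskType("causal-language-model")
is_same_member = task is TaskType.CAUSAL_LM
equals_string = task == "causal-language-model"  # works thanks to the str mixin
value = task.value
```

An unknown string such as TaskType("not-a-task") raises ValueError, which makes value lookup a convenient validation step.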
- class easydel.__init__.Trainer(arguments: tp.Optional[TrainingArguments] = None, model_state: tp.Optional[EasyDeLState] = None, model: tp.type[EasyDeLBaseModule] = None, dataset_train: tp.Optional[Dataset] = None, dataset_eval: tp.Optional[Dataset] = None, data_collator: tp.Optional[tp.Callable] = None, finetune: bool = True, checkpoint_path: tp.Optional[tp.Union[str, os.PathLike]] = None, **deprecated_kwargs)[source]#
Bases:
BaseTrainer
- configure_functions() TrainerConfigureFunctionOutput[source]#
Configures and JIT-compiles the training and evaluation step functions.
This method prepares the functions that will be used during training and evaluation. It sets up sharding for the model parameters and optimizer state, JIT-compiles the training and evaluation functions with the appropriate static arguments and sharding constraints, and also sets up the checkpoint manager.
- Returns
- An object containing:
sharded_training_step_function: The compiled training step function.
sharded_evaluation_step_function: The compiled evaluation step function.
mesh: The device mesh used for computation.
checkpoint_manager: The checkpointer for saving/loading model state.
- Return type
TrainerConfigureFunctionOutput
- create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') Callable[source]#
Creates a collate/collect function to process batches of data for training or evaluation.
This function returns a callable that takes a batch (a list of dictionaries) and converts it into a dictionary of JAX arrays. For models of class “ForCausalLMLoss”, it also performs truncation (either keeping the end or the start of the sequence) so that each sequence does not exceed the specified maximum length.
- Parameters
max_sequence_length (int) – The maximum allowed sequence length.
truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – Determines whether to keep the end or the start of the sequence when truncating. Defaults to “keep_end”.
- Returns
A function that takes a batch (list of dicts) and returns a processed dict of arrays.
- Return type
tp.Callable
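The truncation behaviour can be sketched in plain Python. This version works on lists rather than JAX arrays and assumes every field of a batch item is a token-aligned sequence (e.g. input_ids and attention_mask); the real function also converts to arrays and, per the description above, truncates only for causal-LM models. `make_collate_fn` is hypothetical.

```python
from typing import Callable, Dict, List, Literal

Batch = List[Dict[str, List[int]]]


def make_collate_fn(
    max_sequence_length: int,
    truncation_mode: Literal["keep_end", "keep_start"] = "keep_end",
) -> Callable[[Batch], Dict[str, List[List[int]]]]:
    if truncation_mode not in ("keep_end", "keep_start"):
        raise ValueError(f"unknown truncation_mode: {truncation_mode!r}")

    def truncate(seq: List[int]) -> List[int]:
        if truncation_mode == "keep_end":
            return seq[-max_sequence_length:]  # drop tokens from the front
        return seq[:max_sequence_length]  # drop tokens from the back

    def collate(batch: Batch) -> Dict[str, List[List[int]]]:
        # Gather each field across the batch, truncating every sequence.
        return {key: [truncate(item[key]) for item in batch] for key in batch[0]}

    return collate


batch = [{"input_ids": [1, 2, 3, 4, 5, 6], "attention_mask": [1, 1, 1, 1, 1, 1]}]
kept_end = make_collate_fn(4, "keep_end")(batch)
kept_start = make_collate_fn(4, "keep_start")(batch)
```

keep_end preserves the most recent tokens (useful when the answer sits at the end of a chat transcript), while keep_start preserves the prompt.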
- eval(model_state: EasyDeLState) Iterator[dict][source]#
Evaluates the model using the provided model state.
This method iterates over the evaluation dataset, performs forward passes, calculates evaluation metrics, logs the metrics, and yields the metrics for each evaluation step.
- Parameters
model_state (EasyDeLState) – The state of the model (including parameters and configuration) to be used for evaluation.
- Yields
Iterator[dict] – An iterator yielding a dictionary of evaluation metrics for each evaluation step.
- Raises
AssertionError – If the evaluation dataloader is not set.
- train() TrainerOutput[source]#
Executes the complete training process.
This method sets up initial metrics and logging, runs the training loop, and finalizes training. It calls the training hook at the beginning and returns a TrainerOutput object at the end.
- Returns
An object containing the final training state, metrics, and any additional outputs.
- Return type
TrainerOutput
- class easydel.__init__.TrainingArguments(auto_shard_states: 'bool' = True, aux_loss_enabled: 'bool' = False, backend: 'tp.Optional[str]' = None, clip_grad: 'tp.Optional[float]' = None, custom_scheduler: 'tp.Optional[tp.Callable[[int], tp.Any]]' = None, dataloader_num_workers: 'tp.Optional[int]' = 0, dataloader_pin_memory: 'tp.Optional[bool]' = False, do_eval: 'bool' = False, do_last_save: 'bool' = True, do_train: 'bool' = True, eval_batch_size: 'tp.Optional[int]' = None, evaluation_steps: 'tp.Optional[int]' = None, extra_optimizer_kwargs: 'dict' = <factory>, frozen_parameters: 'tp.Optional[str]' = None, gradient_accumulation_steps: 'int' = 1, ids_to_pop_from_dataset: 'tp.Optional[tp.List[str]]' = <factory>, is_fine_tuning: 'bool' = True, init_tx: 'bool' = True, jax_distributed_config: 'tp.Optional[dict]' = None, learning_rate: 'float' = 5e-05, learning_rate_end: 'tp.Optional[float]' = None, log_all_workers: 'bool' = False, log_grad_norms: 'bool' = True, report_metrics: 'bool' = True, log_steps: 'int' = 10, loss_config: 'tp.Optional[LossConfig]' = None, low_mem_usage: 'bool' = True, max_evaluation_steps: 'tp.Optional[int]' = None, max_sequence_length: 'tp.Optional[int]' = 4096, max_training_steps: 'tp.Optional[int]' = None, model_name: 'str' = 'BaseTrainer', model_parameters: 'tp.Optional[dict]' = None, metrics_to_show_in_rich_pbar: 'tp.Optional[tp.List[str]]' = None, num_train_epochs: 'int' = 10, offload_dataset: 'bool' = False, offload_device_type: 'str' = 'cpu', offload_device_index: 'int' = 0, optimizer: 'AVAILABLE_OPTIMIZERS' = <EasyDeLOptimizers.ADAMW: 'adamw'>, performance_mode: 'bool' = False, pruning_module: 'tp.Any' = None, process_zero_is_admin: 'bool' = True, progress_bar_type: "tp.Literal['tqdm', 'rich', 'json']" = 'tqdm', remove_ckpt_after_load: 'bool' = False, remove_unused_columns: 'bool' = True, report_steps: 'int' = 5, save_directory: 'str' = 'EasyDeL-Checkpoints', save_optimizer_state: 'bool' = True, save_steps: 'tp.Optional[int]' = None, 
save_total_limit: 'tp.Optional[int]' = None, scheduler: 'AVAILABLE_SCHEDULERS' = <EasyDeLSchedulers.NONE: 'None'>, sparsify_module: 'bool' = False, sparse_module_type: 'AVAILABLE_SPARSE_MODULE_TYPES' = 'bcoo', state_apply_fn_kwarguments_to_model: 'tp.Optional[dict]' = None, step_partition_spec: 'PartitionSpec' = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: 'tp.Optional[int]' = None, shuffle_train_dataset: 'bool' = True, total_batch_size: 'int' = 32, training_time_limit: 'tp.Optional[str]' = None, train_on_inputs: 'bool' = True, truncation_mode: "tp.Literal['keep_end', 'keep_start']" = 'keep_end', tx_mu_dtype: 'tp.Optional[jnp.dtype]' = None, track_memory: 'bool' = False, use_data_collactor: 'bool' = True, use_wandb: 'bool' = True, verbose: 'bool' = True, wandb_entity: 'tp.Optional[str]' = None, warmup_steps: 'int' = 0, weight_decay: 'float' = 0.01, weight_distribution_pattern: 'str' = '.*?(layernorm|norm).*?', weight_distribution_log_steps: 'int' = 0)[source]#
Bases:
object
- auto_shard_states: bool = True#
- aux_loss_enabled: bool = False#
- backend: Optional[str] = None#
- clip_grad: Optional[float] = None#
- custom_scheduler: Optional[Callable[[int], Any]] = None#
- dataloader_num_workers: Optional[int] = 0#
- dataloader_pin_memory: Optional[bool] = False#
- do_eval: bool = False#
- do_last_save: bool = True#
- do_train: bool = True#
- eval_batch_size: Optional[int] = None#
- evaluation_steps: Optional[int] = None#
- extra_optimizer_kwargs: dict#
- classmethod from_dict(data)#
Create an instance from a dictionary (deserialization).
- classmethod from_json(json_str)#
Create an instance from a JSON string.
- frozen_parameters: Optional[str] = None#
- get_optimizer_and_scheduler(steps: Optional[int] = None)[source]#
Returns the configured optimizer and learning rate scheduler.
- Parameters
steps (tp.Optional[int]) – The number of training steps. If not provided, uses the value from self.optimizer_kwargs.
- Returns
A tuple containing the optimizer and scheduler.
- Return type
tuple
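How learning_rate, learning_rate_end, and warmup_steps might combine under a linear scheduler can be sketched as a pure function of the step. The shape used here (linear warmup from 0 to learning_rate, then linear decay to learning_rate_end) and the 0.0 end value standing in for learning_rate_end=None are assumptions; the schedule actually returned by get_optimizer_and_scheduler depends on the scheduler setting and EasyDeL's implementation. With optax, comparable shapes can be composed from `optax.linear_schedule` and `optax.join_schedules`.

```python
def linear_warmup_then_decay(
    step: int,
    learning_rate: float = 5e-05,
    learning_rate_end: float = 0.0,  # assumed stand-in when learning_rate_end is None
    warmup_steps: int = 0,
    total_steps: int = 1000,
) -> float:
    if warmup_steps > 0 and step < warmup_steps:
        # Ramp linearly from 0 up to the peak learning rate.
        return learning_rate * step / warmup_steps
    decay_steps = max(total_steps - warmup_steps, 1)
    progress = min((step - warmup_steps) / decay_steps, 1.0)
    # Interpolate linearly from the peak down to learning_rate_end, then hold.
    return learning_rate + (learning_rate_end - learning_rate) * progress


schedule = [
    linear_warmup_then_decay(s, learning_rate=1.0, warmup_steps=10, total_steps=110)
    for s in (0, 5, 10, 60, 110, 200)
]
```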
- get_path() Path[source]#
Returns the path to the checkpoint directory.
- Returns
The path to the checkpoint directory.
- Return type
Path
- get_streaming_checkpointer()[source]#
Returns the checkpoint manager, responsible for saving model checkpoints.
- Returns
The checkpoint manager.
- Return type
- get_tensorboard()[source]#
Returns the TensorBoard SummaryWriter, used for logging metrics.
- Returns
The TensorBoard SummaryWriter.
- Return type
flax.metrics.tensorboard.SummaryWriter
- get_wandb_init()[source]#
Initializes Weights & Biases for experiment tracking if enabled.
- Returns
The WandB run object if initialized, else None.
- Return type
tp.Optional[wandb.sdk.wandb_run.Run]
- gradient_accumulation_steps: int = 1#
- ids_to_pop_from_dataset: Optional[List[str]]#
- init_tx: bool = True#
- is_fine_tuning: bool = True#
- property is_process_zero#
- jax_distributed_config: Optional[dict] = None#
- learning_rate: float = 5e-05#
- learning_rate_end: Optional[float] = None#
- log_all_workers: bool = False#
- log_grad_norms: bool = True#
- log_metrics(metrics: Any, step: int, log_as: Optional[Literal['summary', 'config']] = None)[source]#
Logs training metrics to Weights & Biases and/or TensorBoard.
- Parameters
metrics (tp.Dict[str, tp.Union[float, tp.List, tp.Tuple, np.ndarray, 'jnp.ndarray', 'torch.Tensor']]) – A dictionary where keys are metric names and values are metric values.
step (int) – The current training step or iteration.
- log_steps: int = 10#
- loss_config: Optional[LossConfig] = None#
- low_mem_usage: bool = True#
- max_evaluation_steps: Optional[int] = None#
- max_sequence_length: Optional[int] = 4096#
- max_training_steps: Optional[int] = None#
- metrics_to_show_in_rich_pbar: Optional[List[str]] = None#
- model_name: str = 'BaseTrainer'#
- model_parameters: Optional[dict] = None#
- num_train_epochs: int = 10#
- offload_dataset: bool = False#
- property offload_device#
- offload_device_index: int = 0#
- offload_device_type: str = 'cpu'#
- optimizer: Literal['adafactor', 'lion', 'adamw', 'rmsprop'] = 'adamw'#
- performance_mode: bool = False#
- process_zero_is_admin: bool = True#
- progress_bar_type: Literal['tqdm', 'rich', 'json'] = 'tqdm'#
- pruning_module: Any = None#
- remove_ckpt_after_load: bool = False#
- remove_unused_columns: bool = True#
- replace(**kwargs)#
- report_metrics: bool = True#
- report_steps: int = 5#
- save_directory: str = 'EasyDeL-Checkpoints'#
- save_optimizer_state: bool = True#
- save_steps: Optional[int] = None#
- save_total_limit: Optional[int] = None#
- scheduler: Literal['linear', 'cosine', 'none'] = 'None'#
- shuffle_train_dataset: bool = True#
- sparse_module_type: Literal['bcoo', 'bcsr', 'coo', 'csr'] = 'bcoo'#
- sparsify_module: bool = False#
- state_apply_fn_kwarguments_to_model: Optional[dict] = None#
- step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp')#
- step_start_point: Optional[int] = None#
- to_dict()#
Convert the instance to a dictionary for JSON serialization.
- to_json(**kwargs)#
Convert the instance to a JSON string.
- total_batch_size: int = 32#
- track_memory: bool = False#
- train_on_inputs: bool = True#
- training_time_limit: Optional[str] = None#
- property training_time_seconds: int#
- truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end'#
- use_data_collactor: bool = True#
- use_wandb: bool = True#
- verbose: bool = True#
- wandb_entity: Optional[str] = None#
- warmup_steps: int = 0#
- weight_decay: float = 0.01#
- weight_distribution_log_steps: int = 0#
- weight_distribution_pattern: str = '.*?(layernorm|norm).*?'#
- class easydel.__init__.WhisperConfig(vocab_size=51865, num_mel_bins=80, encoder_layers=4, encoder_attention_heads=6, decoder_layers=4, decoder_attention_heads=6, decoder_ffn_dim=1536, encoder_ffn_dim=1536, encoder_layerdrop=0.0, decoder_layerdrop=0.0, decoder_start_token_id=50257, use_cache=True, is_encoder_decoder=True, activation_function='gelu', d_model=384, dropout=0.0, attention_dropout=0.0, activation_dropout=0.0, init_std=0.02, scale_embedding=False, max_source_positions=1500, max_target_positions=448, pad_token_id=50256, bos_token_id=50256, eos_token_id=50256, suppress_tokens=None, begin_suppress_tokens=[220, 50256], use_weighted_layer_sum=False, classifier_proj_size=256, apply_spec_augment=False, mask_time_prob=0.05, mask_time_length=10, mask_time_min_masks=2, mask_feature_prob=0.0, mask_feature_length=10, mask_feature_min_masks=0, median_filter_width=7, bits: Optional[int] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 51865) – Vocabulary size of the Whisper model. Defines the number of different tokens that can be represented by the inputs_ids passed when calling [~easydel.modules.WhisperModel].
num_mel_bins (int, optional, defaults to 80) – Number of mel bins used by the feature extractor.
encoder_layers (int, optional, defaults to 4) – Number of encoder layers.
encoder_attention_heads (int, optional, defaults to 6) – Number of attention heads for each attention layer in the Transformer encoder.
decoder_layers (int, optional, defaults to 4) – Number of decoder layers.
decoder_attention_heads (int, optional, defaults to 6) – Number of attention heads for each attention layer in the Transformer decoder.
decoder_ffn_dim (int, optional, defaults to 1536) – Dimensionality of the decoder feed-forward network (FFN) layer.
encoder_ffn_dim (int, optional, defaults to 1536) – Dimensionality of the encoder feed-forward network (FFN) layer.
encoder_layerdrop (float, optional, defaults to 0.0) – The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) for more details.
decoder_layerdrop (float, optional, defaults to 0.0) – The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) for more details.
d_model (int, optional, defaults to 384) – Dimensionality of the layers and the pooler layer.
activation_function (str, optional, defaults to “gelu”) – The non-linear activation function (function or string) in the encoder and pooler. If string, “gelu”, “relu”, “silu” and “gelu_new” are supported.
dropout (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
activation_dropout (float, optional, defaults to 0.0) – The dropout ratio for activations inside the fully connected layer.
init_std (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
scale_embedding (bool, optional, defaults to False) – Scale embeddings by dividing by sqrt(d_model).
max_source_positions (int, optional, defaults to 1500) – The maximum sequence length allowed for the source text input to the model. Any longer inputs will be truncated.
max_target_positions (int, optional, defaults to 448) – The maximum sequence length allowed for the target text input to the model. Any longer inputs will be truncated.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models).
apply_spec_augment (bool, optional, defaults to False) – Whether to apply SpecAugment data augmentation.
mask_time_prob (float, optional, defaults to 0.05) – Probability of each feature vector along the time axis to be chosen as the start of the vector span to be masked. Approximately mask_time_prob * sequence_length // mask_time_length spans of length mask_time_length will be masked along the time axis. This is only relevant if apply_spec_augment is set to True.
mask_time_length (int, optional, defaults to 10) – Length of vector span along the time axis.
mask_time_min_masks (int, optional, defaults to 2) – The minimum number of masks of length mask_time_length generated along the time axis. Each time mask is filled with floats sampled in (random_lower_bound, random_upper_bound).
mask_feature_prob (float, optional, defaults to 0.0) – Probability of each feature vector along the feature axis to be chosen as the start of the vector span to be masked. Approximately mask_feature_prob * hidden_size // mask_feature_length spans of length mask_feature_length will be masked along the feature axis. This is only relevant if apply_spec_augment is set to True.
mask_feature_length (int, optional, defaults to 10) – Length of vector span along the feature axis.
mask_feature_min_masks (int, optional, defaults to 0) – The minimum number of masks of length mask_feature_length generated along the feature axis. Each feature mask is filled with floats sampled in (random_lower_bound, random_upper_bound).
median_filter_width (int, optional, defaults to 7) – The width of the median filter used to smooth the cross-attention weights when computing token-level timestamps.
bits (int, optional) – The number of bits to quantize the model to. If None, the model is not quantized.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing policy, i.e. which intermediate activations to save for the backward pass.
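As a worked example of the SpecAugment settings above, the sketch below applies the mask_time_prob * sequence_length // mask_time_length estimate, read as a count of mask spans, together with the lower bound set by mask_time_min_masks. The helper name approx_num_time_masks is hypothetical, and EasyDeL's actual sampling (randomness, rounding) may differ; this only reproduces the stated arithmetic with the default values.

```python
def approx_num_time_masks(
    sequence_length: int,
    mask_time_prob: float = 0.05,
    mask_time_length: int = 10,
    mask_time_min_masks: int = 2,
) -> int:
    """Approximate number of time-axis mask spans (hypothetical helper).

    Uses the documented estimate mask_time_prob * sequence_length // mask_time_length
    and never returns fewer than mask_time_min_masks spans.
    """
    estimate = int(mask_time_prob * sequence_length // mask_time_length)
    return max(estimate, mask_time_min_masks)


# A 3000-frame input: 0.05 * 3000 // 10 -> 15 spans of 10 frames each.
print(approx_num_time_masks(3000))  # 15
# A short input: the estimate floors to 0, so the minimum of 2 spans applies.
print(approx_num_time_masks(100))  # 2
```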
- attribute_map: dict[str, str] = {'hidden_size': 'd_model', 'num_attention_heads': 'encoder_attention_heads'}#
- get_partition_rules(*args, **kwargs)[source]#
Returns the partition rules for the Whisper model. Arguments are ignored.
- Returns
Partition rules.
- Return type
tuple
- model_type: str = 'whisper'#
- class easydel.__init__.WhisperForAudioClassification(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.WhisperForConditionalGeneration(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- compute_loss(*, labels: Optional[Union[Array, ndarray, bool, number]] = None, loss_config: Optional[LossConfig] = None, loss_kwargs: Optional[Dict] = None, **batch) Tuple[Any, LossMetrics][source]#
Computes the loss for the model given a batch of inputs and labels.
This method performs a forward pass using the provided batch arguments, then calculates the loss using the determined loss_function. It handles potential label inference (e.g., using input_ids as labels for Causal LM) and default loss configurations.
- Parameters
labels (tp.Optional[chex.Array], optional) – The target labels. If None and the task is Causal LM, input_ids from the batch might be used. Defaults to None.
loss_config (tp.Optional[LossConfig], optional) – Specific configuration for the loss calculation. If None, defaults might be inferred (e.g., for sequence classification). Defaults to None.
loss_kwargs (tp.Optional[tp.Dict], optional) – Additional keyword arguments to pass directly to the loss function. Defaults to None.
**batch – Keyword arguments representing the input batch (e.g., input_ids, attention_mask).
- Returns
- A tuple containing:
The model’s output (a pytree typically including logits, hidden states, etc.)
A LossMetrics object containing the calculated loss and potentially other metrics.
- Return type
tp.Tuple[tp.Any, LossMetrics]
- Raises
AssertionError – If labels are required for the loss function but are not provided or inferred.
AssertionError – If sequence classification loss is used without num_labels in the config.
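Reusing input_ids as labels for causal LM usually implies a one-token shift: the logits at position t are scored against token t + 1. The NumPy sketch below illustrates that computation. The function name and the ignore_index=-100 convention (borrowed from PyTorch) are assumptions; this is not EasyDeL's actual loss_function, which may differ in masking and normalization.

```python
import numpy as np


def causal_lm_loss_sketch(logits: np.ndarray, labels: np.ndarray, ignore_index: int = -100) -> float:
    """Mean next-token cross-entropy (illustrative, not EasyDeL's implementation).

    logits: (batch, seq_len, vocab); labels: (batch, seq_len), e.g. a copy of input_ids.
    """
    # Position t predicts token t + 1, so drop the last logit and the first label.
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]

    # Numerically stable log-softmax over the vocabulary axis.
    shifted = shift_logits - shift_logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    # Ignore masked targets; a dummy index 0 keeps the gather in range for them.
    valid = shift_labels != ignore_index
    safe_labels = np.where(valid, shift_labels, 0)
    nll = -np.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]

    return float((nll * valid).sum() / max(int(valid.sum()), 1))


# Uniform logits over a 4-token vocabulary give a loss of log(4), about 1.386.
logits = np.zeros((1, 3, 4))
labels = np.array([[1, 2, 3]])
print(causal_lm_loss_sketch(logits, labels))
```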
- decode(decoder_input_ids, encoder_outputs, encoder_attention_mask: Optional[Array] = None, decoder_attention_mask: Optional[Array] = None, decoder_position_ids: Optional[Array] = None, past_key_values: Optional[TransformerCache] = None, cache_metadata: Optional[TransformerMetadata] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None)[source]#
- encode(input_features: Array, attention_mask: Optional[Array] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs)[source]#
- generate(input_features, generation_config=None, logits_processor=None, return_timestamps=None, task=None, language=None, is_multilingual=None, **kwargs)[source]#
Generates sequences of token ids for models with a language modeling head.
- Parameters
input_features (chex.Array of shape (batch_size, num_mel_bins, sequence_length)) – The log-mel input features of the audio to process.
generation_config (~generation.GenerationConfig, optional) – The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them. If generation_config is not provided, the default will be used, which has the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration. Please note that unspecified parameters will inherit [~generation.GenerationConfig]’s default values, whose documentation should be checked to parameterize generation.
trace (bool, optional, defaults to True) – Whether to trace generation. Setting trace=False should only be used for debugging and will lead to a considerably slower runtime.
logits_processor (LogitsProcessorList, optional) – Custom logits processors that complement the default logits processors built from arguments and generation config. If a logits processor is passed that has already been created from the arguments or the generation config, an error is thrown. This feature is intended for advanced users.
kwargs (tp.Dict[str, Any], optional) – Ad hoc parametrization of generation_config and/or additional model-specific kwargs that will be forwarded to the forward function of the model. If the model is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with decoder_.
- Returns
[~utils.ModelOutput].
- loss_type = 'ForCausalLM'#
- prepare_inputs_for_generation(decoder_input_ids, max_length, attention_mask: Optional[Array] = None, decoder_attention_mask: Optional[Array] = None, encoder_outputs=None, **kwargs)[source]#
Prepares the inputs for a generation task.
- Parameters
decoder_input_ids – The decoder input token IDs to start generation from.
max_length – The maximum length of the sequence to be generated.
attention_mask (tp.Optional[chex.Array]) – Optional mask applied to the attention weights.
decoder_attention_mask (tp.Optional[chex.Array]) – Optional attention mask for the decoder inputs.
encoder_outputs – The precomputed encoder outputs.
- Returns
A dictionary containing the past_key_values, attention mask, and position IDs.
- class easydel.__init__.WhisperTimeStampLogitsProcessor(generate_config, model_config, decoder_input_length)[source]#
Bases:
LogitsProcessor
A specialized [LogitsProcessor] tailored for handling timestamp tokens during generation with Whisper-style models used for Automatic Speech Recognition (ASR).
It enforces several constraints specific to timestamp prediction:
1. Suppresses `<|notimestamps|>`: Prevents the model from predicting the token that indicates the absence of timestamps.
2. Alternating Tokens: Enforces that text tokens and timestamp tokens generally alternate. If the last generated token was a timestamp, it biases against predicting another timestamp immediately after (unless it’s the very beginning or certain edge cases).
3. Initial Timestamp Limit: Restricts the maximum value of the first timestamp token predicted using max_initial_timestamp_index.
4. Timestamp Probability Check: If the total probability mass assigned to all valid timestamp tokens is higher than the probability of the single most likely non-timestamp token, it forces the model to sample a timestamp token by suppressing all non-timestamp tokens.
Note
This processor assumes the existence of specific token IDs related to timestamps (e.g., eos_token_id, no_timestamps_token_id, timestamp_begin) which are typically defined in the model’s generation configuration.
- Parameters
generate_config – Configuration object containing Whisper-specific generation parameters like eos_token_id, no_timestamps_token_id, is_multilingual, max_initial_timestamp_index.
model_config – The model’s configuration (used for vocab_size as a fallback).
decoder_input_length – The length of the initial input sequence provided to the decoder (e.g., the prompt length).
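The `<|notimestamps|>` suppression, the initial timestamp limit, and the timestamp probability check described above can be sketched for a single vector of logits in NumPy. This is a hypothetical helper, not the class's actual implementation: it assumes every token ID >= timestamp_begin is a timestamp token, takes an explicit is_first_step flag instead of tracking decoder_input_length, and leaves out the alternating-tokens constraint.

```python
from typing import Optional

import numpy as np


def _logsumexp(x: np.ndarray) -> float:
    """Stable log(sum(exp(x))); -inf entries contribute nothing."""
    m = np.max(x)
    return float(m + np.log(np.sum(np.exp(x - m))))


def timestamp_rules_sketch(
    logits: np.ndarray,
    timestamp_begin: int,
    no_timestamps_token_id: int,
    is_first_step: bool,
    max_initial_timestamp_index: Optional[int] = None,
) -> np.ndarray:
    """Apply three of the timestamp constraints to one row of logits (illustrative only)."""
    out = logits.astype(np.float64).copy()

    # Never predict <|notimestamps|>.
    out[no_timestamps_token_id] = -np.inf

    # On the first generated token, cap how late the initial timestamp may be.
    if is_first_step and max_initial_timestamp_index is not None:
        out[timestamp_begin + max_initial_timestamp_index + 1 :] = -np.inf

    # If all timestamps together outweigh the best non-timestamp token, force a timestamp.
    log_probs = out - _logsumexp(out)
    timestamp_logprob = _logsumexp(log_probs[timestamp_begin:])
    max_text_logprob = np.max(log_probs[:timestamp_begin])
    if timestamp_logprob > max_text_logprob:
        out[:timestamp_begin] = -np.inf
    return out


# Toy vocabulary: ids 0-5 are text/special tokens (5 = <|notimestamps|>), ids 6-9 are timestamps.
row = timestamp_rules_sketch(np.zeros(10), timestamp_begin=6, no_timestamps_token_id=5, is_first_step=False)
print(np.isneginf(row[:6]).all())  # True: timestamps hold 4/9 of the mass vs 1/9 for the best text token
```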
- class easydel.__init__.Xerxes2Config(vocab_size: int = 256128, hidden_size: int = 4096, intermediate_size: int = 16384, num_hidden_layers: int = 32, num_attention_heads: int = 32, max_position_embeddings: int = 16384, initializer_range: float = 0.02, rms_norm_eps: float = 1e-06, use_cache: bool = True, pad_token_id: int = 0, eos_token_id: int = 1, bos_token_id: int = 2, tie_word_embeddings: bool = False, rope_theta: float = 10000.0, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, scan_layers: bool = False, q_lora_dim: Optional[int] = 1536, kv_lora_dim: int = 512, qk_rope_head_dim: int = 64, qk_nope_head_dim: int = 128, vhead_dim: int = 128, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 256128) – Vocabulary size of the Xerxes2 model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.
hidden_size (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.
intermediate_size (int, optional, defaults to 16384) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (int, optional, defaults to 16) – Number of key and value heads for each attention layer in the Transformer encoder.
head_dim (int, optional, defaults to 256) – Dimensionality of the attention head.
max_position_embeddings (int, optional, defaults to 16384) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
pad_token_id (int, optional, defaults to 0) – The index of the padding token in the vocabulary.
eos_token_id (int, optional, defaults to 1) – The index of the end of sequence token in the vocabulary.
bos_token_id (int, optional, defaults to 2) – The index of the beginning of sequence token in the vocabulary.
tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.
rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.
softmax_scale (float, optional, defaults to 14.9666295471) – softmax scale for attention module.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.
bits (int, optional) – The number of bits to quantize the model to.
scan_layers (bool, optional, defaults to False) – Whether to use the scan implementation of the layers.
- attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#
The attach_custom_arguments function attaches the following EasyDeL-specific arguments:
- Parameters
gradient_checkpointing (EasyDeLGradientCheckPointers) – The gradient checkpointing policy, which controls how much activation memory JAX keeps for the backward pass.
bits (tp.Optional[int]) – The number of bits used for quantization.
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model.
- Parameters
fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
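The returned rules are pairs of a regular expression and a PartitionSpec. The sketch below shows how such rules are commonly applied, with the first matching pattern winning. The rule set, the mesh axis names ("fsdp", "sp", "tp"), and the "/"-joined parameter path format are assumptions for illustration, not the rules Xerxes2Config actually returns.

```python
import re

from jax.sharding import PartitionSpec as PS

# Hypothetical rules: (regex over a "/"-joined parameter path, PartitionSpec).
EXAMPLE_RULES = (
    (r"embed_tokens/embedding", PS("tp", ("fsdp", "sp"))),
    (r"self_attn/(q_proj|k_proj|v_proj)/kernel", PS(("fsdp", "sp"), "tp")),
    (r"self_attn/o_proj/kernel", PS("tp", ("fsdp", "sp"))),
    (r".*", PS()),  # catch-all: replicate everything else
)


def spec_for(path: str, rules=EXAMPLE_RULES) -> PS:
    """Return the PartitionSpec of the first rule whose regex matches the path."""
    for pattern, spec in rules:
        if re.search(pattern, path):
            return spec
    raise ValueError(f"no partition rule matches {path!r}")


print(spec_for("model/layers/0/self_attn/q_proj/kernel"))
print(spec_for("model/norm/scale"))
```

Putting a catch-all pattern last keeps every parameter covered; without it, an unmatched path raises instead of being silently replicated.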
- property granted_freq_max_position_embedding: int#
- property granted_mask_max_position_embedding: int#
- model_type: str = 'xerxes2'#
- class easydel.__init__.Xerxes2ForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.Xerxes2Model(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.XerxesConfig(vocab_size=256128, hidden_size=4096, intermediate_size=16384, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=8, head_dim=144, max_position_embeddings=16384, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, pad_token_id=0, eos_token_id=1, bos_token_id=2, num_local_experts: int = 4, xe_moe: bool = True, num_experts_per_tok: int = 2, tie_word_embeddings=False, rope_theta=10000.0, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, scan_layers: bool = False, **kwargs)[source]#
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 256128) – Vocabulary size of the Xerxes model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.
hidden_size (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.
intermediate_size (int, optional, defaults to 16384) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (int, optional, defaults to 8) – Number of key and value heads for each attention layer in the Transformer encoder.
head_dim (int, optional, defaults to 144) – Dimensionality of the attention head.
max_position_embeddings (int, optional, defaults to 16384) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
pad_token_id (int, optional, defaults to 0) – The index of the padding token in the vocabulary.
eos_token_id (int, optional, defaults to 1) – The index of the end of sequence token in the vocabulary.
bos_token_id (int, optional, defaults to 2) – The index of the beginning of sequence token in the vocabulary.
tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.
rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.
softmax_scale (float, optional, defaults to 14.9666295471) – softmax scale for attention module.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.
bits (int, optional) – The number of bits to quantize the model to.
scan_layers (bool, optional, defaults to False) – Whether to use the scan implementation of the layers.
- attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#
The attach_custom_arguments function attaches the following EasyDeL-specific arguments:
- Parameters
gradient_checkpointing (EasyDeLGradientCheckPointers) – The gradient checkpointing policy, which controls how much activation memory JAX keeps for the backward pass.
bits (tp.Optional[int]) – The number of bits used for quantization.
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model.
- Parameters
fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- property granted_freq_max_position_embedding: int#
- property granted_mask_max_position_embedding: int#
- model_type: str = 'xerxes'#
- class easydel.__init__.XerxesForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.__init__.XerxesModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- easydel.__init__.auto_pytree(cls=None, meta_fields: Optional[Tuple[str, ...]] = None, json_serializable: bool = True, frozen: bool = False)[source]#
Register a class as a JAX PyTree with performance optimizations.
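auto_pytree's internals are not described here. The underlying idea, registering a class so JAX transformations flatten it into array leaves while keeping some fields as static metadata (compare the meta_fields parameter), can be illustrated with plain JAX's jax.tree_util.register_dataclass. This sketches the concept only: it does not reproduce auto_pytree's JSON serialization or its performance optimizations, the Batch class is hypothetical, and it needs a JAX version that provides register_dataclass.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Batch:
    input_ids: jax.Array  # array field -> PyTree leaf
    attention_mask: jax.Array  # array field -> PyTree leaf
    split: str = "train"  # static metadata, kept out of the leaves


# Data fields become leaves; meta fields are stored in the tree structure as static values.
jax.tree_util.register_dataclass(
    Batch, data_fields=["input_ids", "attention_mask"], meta_fields=["split"]
)

batch = Batch(jnp.array([[1, 2, 3]]), jnp.array([[1, 1, 0]]))
print(len(jax.tree_util.tree_leaves(batch)))  # 2: the string field is not a leaf

# tree_map and jit only see the array leaves; the metadata is carried through unchanged.
doubled = jax.jit(lambda b: jax.tree_util.tree_map(lambda x: x * 2, b))(batch)
print(doubled.input_ids, doubled.split)
```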
- easydel.__init__.get_modules_by_type(model_type: str, task_type: TaskType) Tuple[Type[EasyDeLBaseConfig], Union[Type[EasyDeLBaseModule], Any]][source]#
- Helper that, given a model type and a task type, returns a tuple of:
The config class for the model type specified (e.g., LlamaConfig, FalconConfig)
The EasyDeL module class registered for that model type and task (e.g., LlamaForCausalLM, FalconForCausalLM)
- easydel.__init__.module_to_huggingface_model(module: ~typing.Any, config: ~typing.Any, base_huggingface_module: ~typing.Any, base_huggingface_module_kwarguments: ~typing.Optional[~typing.Dict] = None, dtype: ~numpy.dtype = <class 'jax.numpy.float16'>, use_meta_torch: bool = True, **kw)[source]#
- easydel.__init__.module_to_torch(module: ~typing.Any, dtype: ~numpy.dtype = <class 'jax.numpy.float16'>)[source]#
- easydel.__init__.pack_sequences(dataset: Any, max_length: int = 512, pad_token_id: int = 0, reset_position_ids: bool = False, num_proc: Optional[int] = None)[source]#
Pack sequences together with their attention masks and position IDs.
# With continuous position IDs
packed_dataset = pack_sequences(
    dataset, max_length=512, pad_token_id=0, reset_position_ids=False
)
# With reset position IDs for each sequence
packed_dataset = pack_sequences(
    dataset, max_length=512, pad_token_id=0, reset_position_ids=True
)
# Example output format for a packed sequence with two sequences:
# reset_position_ids=False:
{
    'input_ids': [seq1_tokens + [PAD] + seq2_tokens + [PAD] + padding],
    'attention_mask': [1,1,1,0,1,1,1,0,0,0],
    'position_ids': [0,1,2,3,4,5,6,7,0,0]
}
# reset_position_ids=True:
{
    'input_ids': [seq1_tokens + [PAD] + seq2_tokens + [PAD] + padding],
    'attention_mask': [1,1,1,0,1,1,1,0,0,0],
    'position_ids': [0,1,2,0,0,1,2,0,0,0]
}
- Parameters
dataset – Dataset containing ‘input_ids’ and ‘attention_mask’
max_length – Maximum length of packed sequence
pad_token_id – Token ID used for padding
reset_position_ids – If True, reset position IDs for each sequence in the pack
- Returns
Dataset with packed sequences, attention masks, and position IDs
- Return type
packed_dataset
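The packing behind the example output for pack_sequences can be sketched greedily over in-memory token lists. The function below is hypothetical: it works on plain lists instead of a dataset, appends one pad_token_id separator (with attention mask 0) after each sequence as the example shows, and truncates any sequence longer than max_length - 1. The real pack_sequences may handle overflow and separators differently.

```python
from typing import Dict, List, Sequence

import numpy as np


def pack_sequences_sketch(
    sequences: Sequence[Sequence[int]],
    max_length: int = 10,
    pad_token_id: int = 0,
    reset_position_ids: bool = False,
) -> List[Dict[str, np.ndarray]]:
    """Greedily pack token lists into fixed-length rows (illustrative only)."""
    packs: List[Dict[str, np.ndarray]] = []
    ids: List[int] = []
    mask: List[int] = []
    pos: List[int] = []

    def flush() -> None:
        # Right-pad the current row to max_length and start a new one.
        pad = max_length - len(ids)
        packs.append({
            "input_ids": np.array(ids + [pad_token_id] * pad),
            "attention_mask": np.array(mask + [0] * pad),
            "position_ids": np.array(pos + [0] * pad),
        })
        ids.clear()
        mask.clear()
        pos.clear()

    for seq in sequences:
        chunk = list(seq)[: max_length - 1]  # leave room for the separator PAD
        if len(ids) + len(chunk) + 1 > max_length:
            flush()
        start = len(ids)  # continuous positions simply follow the row offset
        ids.extend(chunk + [pad_token_id])
        mask.extend([1] * len(chunk) + [0])
        if reset_position_ids:
            pos.extend(list(range(len(chunk))) + [0])
        else:
            pos.extend(range(start, start + len(chunk) + 1))
    if ids:
        flush()
    return packs


packed = pack_sequences_sketch([[11, 12, 13], [21, 22, 23]], max_length=10)
print(packed[0]["position_ids"].tolist())  # [0, 1, 2, 3, 4, 5, 6, 7, 0, 0]
```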
- easydel.__init__.register_config(config_type: str, config_field: ConfigType = ConfigType.MODULE_CONFIG) callable#
Decorator factory to register a configuration class.
- Parameters
config_type (str) – A unique string identifier for this configuration class (e.g., “llama”).
config_field (ConfigType) – The category under which to register the config. Defaults to ConfigType.MODULE_CONFIG.
- Returns
A decorator that takes the configuration class, registers it, and enhances its string representation.
- Return type
callable
- easydel.__init__.register_module(task_type: TaskType, config: EasyDeLBaseConfig, model_type: str, embedding_layer_names: Optional[List[str]] = None, layernorm_names: Optional[List[str]] = None) callable#
Decorator factory to register an EasyDeL module class for a specific task.
- Parameters
task_type (TaskType) – The task the module is designed for (e.g., TaskType.CAUSAL_LM).
config (EasyDeLBaseConfig) – The configuration class associated with this module.
model_type (str) – A unique string identifier for this model implementation (e.g., “llama”).
embedding_layer_names (tp.Optional[tp.List[str]]) – Optional list of embedding layer names. Defaults to None.
layernorm_names (tp.Optional[tp.List[str]]) – Optional list of LayerNorm layer names. Defaults to None.
- Returns
A decorator that takes the module class, registers it with its metadata, and sets internal _model_task and _model_type attributes on the class.
- Return type
callable
- easydel.__init__.torch_dict_to_easydel_params(state_dict: ~typing.Dict[str, ~typing.Any], *, device: ~typing.Optional[~jaxlib.xla_extension.Device] = None, embedding_layer_names: ~typing.Optional[~typing.List[str]] = None, layernorm_names: ~typing.Optional[~typing.List[str]] = None, shard_fns: ~typing.Optional[~typing.Mapping[tuple, ~typing.Callable]] = None, dtype: ~numpy.dtype = <class 'jax.numpy.float16'>, verbose: bool = True, callback: ~typing.Optional[~typing.Callable[[~jax.Array, tuple], ~jax.Array]] = None, remove_state_dict: bool = False, lm_head_name: ~typing.Optional[str] = None, uses_tie_word_embedding: bool = False, **kwargs) Dict[str, Any][source]#
Convert PyTorch state dict to EasyDeL parameter format.
- Parameters
state_dict – PyTorch state dictionary
device – JAX device to use
embedding_layer_names – Names of embedding layers
layernorm_names – Names of layer normalization layers
shard_fns – Mapping from parameter paths (tuples of keys) to sharding functions
block_size – Size of processing blocks
params_pattern_selection – Regex pattern for parameter selection
dtype – Target dtype for parameters
verbose – Whether to show a progress bar
callback – Callback applied to each tensor after it is converted to a JAX array
remove_state_dict – Whether to delete state_dict after conversion
lm_head_name – Name of language model head
uses_tie_word_embedding – Whether model uses tied embeddings
**kwargs – Additional arguments
- Returns
Dictionary of converted parameters in EasyDeL format
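The per-tensor renaming and transposition that PyTorch-to-Flax converters typically perform can be sketched with NumPy arrays standing in for torch tensors. The sketch follows the common convention: Linear weights are transposed and renamed to kernel, embedding weights are renamed to embedding, and norm weights are renamed to scale. The function name is hypothetical. It is also an assumption that torch_dict_to_easydel_params follows exactly these rules. Device placement, sharding, callbacks, and tied embeddings are omitted.

```python
from typing import Any, Dict, Iterable

import numpy as np


def convert_state_dict_sketch(
    state_dict: Dict[str, Any],
    embedding_layer_names: Iterable[str] = (),
    layernorm_names: Iterable[str] = (),
    dtype=np.float16,
) -> Dict[str, Any]:
    """Turn flat 'a.b.weight' keys into a nested Flax-style dict (illustrative only)."""
    embedding_layer_names = tuple(embedding_layer_names)
    layernorm_names = tuple(layernorm_names)
    params: Dict[str, Any] = {}
    for name, value in state_dict.items():
        *path, leaf = name.split(".")
        array = np.asarray(value)
        if leaf == "weight":
            if any(layer in path for layer in embedding_layer_names):
                leaf = "embedding"  # embedding tables keep their (vocab, hidden) layout
            elif any(layer in path for layer in layernorm_names):
                leaf = "scale"  # norm weights are called 'scale' in Flax
            elif array.ndim == 2:
                leaf = "kernel"  # torch Linear is (out, in); Flax Dense kernel is (in, out)
                array = array.T
        node = params
        for key in path:
            node = node.setdefault(key, {})  # build the nested dict one level at a time
        node[leaf] = array.astype(dtype)
    return params


rng = np.random.default_rng(0)
state_dict = {
    "model.embed_tokens.weight": rng.normal(size=(10, 4)),
    "model.layers.0.self_attn.q_proj.weight": rng.normal(size=(8, 4)),
    "model.norm.weight": np.ones(4),
}
params = convert_state_dict_sketch(
    state_dict, embedding_layer_names=("embed_tokens",), layernorm_names=("norm",)
)
print(params["model"]["layers"]["0"]["self_attn"]["q_proj"]["kernel"].shape)  # (4, 8)
```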
- class easydel.__init__.vInference(model: None, processor_class: None, generation_config: Optional[vInferenceConfig] = None, seed: Optional[int] = None, input_partition_spec: Optional[PartitionSpec] = None, max_new_tokens: int = 512, inference_name: Optional[str] = None)[source]#
Bases:
object
Class for performing text generation using a pre-trained language model in EasyDeL.
This class handles the generation process, including initialization, precompilation, and generating text in streaming chunks.
- property SEQUENCE_DIM_MAPPING#
- execute_decode(state: SampleState, *, graphstate: Optional[State[Key, VariableState[Any]]] = None, graphother: Optional[State[Key, VariableState[Any]]] = None, compile_config: Optional[vInferencePreCompileConfig] = None, sampling_params: Optional[SamplingParams] = None, func: Optional[Callable[[Any], SampleState]]) SampleState[source]#
- execute_prefill(state: SampleState, *, graphstate: Optional[State[Key, VariableState[Any]]] = None, graphother: Optional[State[Key, VariableState[Any]]] = None, compile_config: Optional[vInferencePreCompileConfig] = None, sampling_params: Optional[SamplingParams] = None, func: Optional[Callable[[Any], SampleState]]) SampleState[source]#
Executes a single generation step with performance monitoring.
- generate(input_ids: Array, attention_mask: Optional[Array] = None, *, graphstate: Optional[State[Key, VariableState[Any]]] = None, graphother: Optional[State[Key, VariableState[Any]]] = None, sampling_params: Optional[SamplingParams] = None, **model_kwargs) Generator[Union[SampleState, Any], SampleState, SampleState][source]#
Generates text in streaming chunks with comprehensive input adjustment.
- Parameters
input_ids – Input token IDs as a JAX array
attention_mask – Optional attention mask for the input
graphstate (nn.GraphState, optional) – Optional model state (e.g., updated parameters) to use for generation instead of the current one.
graphother (nn.GraphState, optional) – Optional remaining (non-parameter) model state to use for generation instead of the current one.
**model_kwargs – Additional model-specific keyword arguments
- Returns
Generator yielding SampleState objects containing generation results and metrics
- property inference_name#
- property metrics#
- property model#
- property model_prefill_length: int#
Calculate the maximum length available for input prefill by subtracting the maximum new tokens from the model’s maximum sequence length.
- Returns
The maximum length available for input prefill
- Return type
int
- Raises
ValueError – If no maximum sequence length configuration is found
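The calculation is a single subtraction. Here it is as a sketch. The helper name is hypothetical. Where vInference actually reads the maximum sequence length from is not shown in this section. The value 512 matches the max_new_tokens default in the vInference signature.

```python
from typing import Optional


def model_prefill_length_sketch(max_sequence_length: Optional[int], max_new_tokens: int) -> int:
    """Tokens left for the prompt after reserving room for generation (hypothetical helper)."""
    if max_sequence_length is None:
        raise ValueError("no maximum sequence length configuration found")
    return max_sequence_length - max_new_tokens


# A 4096-token context with max_new_tokens=512 leaves 3584 tokens for the prompt.
print(model_prefill_length_sketch(4096, 512))  # 3584
```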
- precompile(config: vInferencePreCompileConfig)[source]#
Precompiles the generation functions for a given batch size and input length.
This function checks if the generation functions have already been compiled for the given configuration. If not, it compiles them asynchronously and stores them in a cache.
- Returns
True if precompilation was successful, False otherwise.
- Return type
bool
- property tokenizer#
- class easydel.__init__.vInferenceApiServer(inference_map: Union[Dict[str, Any], Any] = None, inference_init_call: Optional[Callable[[], Any]] = None, max_workers: int = 10)[source]#
Bases:
object
FastAPI server for serving vInference instances.
This server provides endpoints mimicking the OpenAI API structure for chat completions, liveness/readiness checks, token counting, and listing available models. It handles both streaming and non-streaming requests asynchronously using a thread pool.
- async chat_completions(request: ChatCompletionRequest)[source]#
Handles chat completion requests (POST /v1/chat/completions).
Validates the request, retrieves the appropriate vInference model, tokenizes the input, and delegates to streaming or non-streaming handlers.
- Parameters
request (ChatCompletionRequest) – The incoming request data.
- Returns
The generated response: either a complete JSON object or a streaming event stream.
- Return type
Union[JSONResponse, StreamingResponse]
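Because the endpoints mimic the OpenAI API, a client request can be built with the standard library alone. This sketch only constructs the request and does not send it. It assumes the server is running locally on the default port from fire (11556) and that the model field names one of the served vInference instances; the model name "my-model" is hypothetical, and the payload follows the OpenAI chat-completions shape rather than a schema checked against ChatCompletionRequest.

```python
import json
import urllib.request


def build_chat_request(
    model: str,
    prompt: str,
    stream: bool = False,
    base_url: str = "http://localhost:11556",
) -> urllib.request.Request:
    """Build (but do not send) an OpenAI-style chat completion request."""
    payload = {
        "model": model,  # hypothetical: name of a served vInference instance
        "messages": [{"role": "user", "content": prompt}],
        "stream": stream,  # True requests an event-stream response
    }
    return urllib.request.Request(
        url=f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


req = build_chat_request("my-model", "Hello!")
print(req.get_method(), req.full_url)  # POST http://localhost:11556/v1/chat/completions
# To send it to a running server:
# with urllib.request.urlopen(req) as resp:
#     print(json.loads(resp.read()))
```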
- async completions(request: CompletionRequest)[source]#
Handles completion requests (POST /v1/completions).
Processes the prompt for completion and returns generated text.
- Parameters
request (CompletionRequest) – The incoming request data.
- Returns
The generated response.
- Return type
Union[JSONResponse, StreamingResponse]
- async count_tokens(request: CountTokenRequest)[source]#
Token counting endpoint (POST /v1/count_tokens).
- fire(host='0.0.0.0', port=11556, metrics_port: Optional[int] = None, log_level='info', ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None)[source]#
Starts the uvicorn server to run the FastAPI application.
- Parameters
host (str) – The host address to bind to. Defaults to “0.0.0.0”.
port (int) – The port to listen on. Defaults to 11556.
metrics_port (tp.Optional[int]) – The port for the Prometheus metrics server. If None, defaults to port + 1. Set to -1 to disable.
log_level (str) – The logging level for uvicorn. Defaults to “info”.
ssl_keyfile (tp.Optional[str]) – Path to the SSL key file for HTTPS.
ssl_certfile (tp.Optional[str]) – Path to the SSL certificate file for HTTPS.
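The metrics_port rule documented above has three cases. A minimal sketch of how they resolve; resolve_metrics_port is a hypothetical helper for illustration, not part of EasyDeL's API.

```python
from typing import Optional


def resolve_metrics_port(port: int, metrics_port: Optional[int]) -> Optional[int]:
    """Hypothetical helper: return the Prometheus port, or None if disabled."""
    if metrics_port is None:
        return port + 1  # default: the port right after the API port
    if metrics_port == -1:
        return None  # -1 disables the metrics server
    return metrics_port  # an explicitly chosen port


print(resolve_metrics_port(11556, None))  # 11557
```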
- class easydel.__init__.vInferenceConfig(max_new_tokens: int = 64, streaming_chunks: int = 16, num_return_sequences: Optional[Union[int, Dict[int, int]]] = 1, pad_token_id: Optional[int] = None, bos_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, partition_rules: Optional[Tuple[Tuple[str, Any]]] = None, partition_axis: Optional[PartitionAxis] = None, _loop_rows: Optional[int] = None, sampling_params: Optional[SamplingParams] = None)[source]#
Bases: object
Configuration class for the vInference engine, controlling the overall generation process.
This class holds parameters that define how the generation loop behaves, including length constraints, token control, sharding strategies, and sampling settings.
- max_new_tokens#
The maximum number of new tokens to generate, excluding the initial prompt tokens. Defaults to 64.
- Type
int
- streaming_chunks#
The number of generation steps to compile and execute together as a single unit. Larger chunks can improve throughput, particularly on TPUs, by reducing per-step dispatch and host synchronization overhead, at the cost of coarser streaming granularity and potentially higher memory usage. Defaults to 16.
- Type
int
- num_return_sequences#
The number of sequences to generate and return. Either an integer (generate this many sequences for all inputs) or a dictionary mapping a precompile hash (from vInferencePreCompileConfig) to an integer (generate a specific number of sequences for that compilation configuration). Defaults to 1.
- Type
Optional[Union[int, Dict[int, int]]]
- pad_token_id#
The token ID used for padding sequences. If None, the model’s default pad token ID might be used, or padding might not be applied.
- Type
Optional[int]
- bos_token_id#
The token ID representing the beginning-of-sequence. May be used implicitly by the model or generation logic.
- Type
Optional[int]
- eos_token_id#
The token ID(s) representing the end-of-sequence. Generation stops for a sequence when one of these tokens is sampled. Can be a single integer or a list/tuple of integers.
- Type
Optional[Union[int, List[int]]]
- partition_rules#
A tuple of custom sharding rules (regex pattern, PartitionSpec) to apply to the model’s parameters and intermediate states (such as the attention cache). If None, default rules based on partition_axis are generated. Example: ((".*kernel.*", PartitionSpec("fsdp", None)), ...)
- Type
Optional[Tuple[Tuple[str, Any]]]
- partition_axis#
A PartitionAxis object defining the logical names for sharding axes (e.g., ‘batch’, ‘sequence’, ‘head’). Required if partition_rules is None, since it is used to generate the default sharding rules.
- Type
Optional[eformer.escale.partition.constraints.PartitionAxis]
- _loop_rows#
(Internal) The calculated number of iterations needed in the generation loop based on max_new_tokens and streaming_chunks. Automatically computed in __post_init__.
- Type
Optional[int]
- sampling_params#
A SamplingParams object containing parameters for the sampling process itself (e.g., temperature, top_k, top_p, repetition penalty). If None, a default SamplingParams instance with max_tokens set to max_new_tokens is created in __post_init__.
- Type
Optional[easydel.__init__.inference.utilities.SamplingParams]
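The two derived fields described above (_loop_rows and the default sampling_params) can be sketched as a dataclass __post_init__. This is an illustrative re-creation, not EasyDeL's source: _loop_rows is assumed to be the ceiling of max_new_tokens / streaming_chunks, and a plain dict stands in for SamplingParams.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class InferenceConfigSketch:
    """Illustrative stand-in for vInferenceConfig's derived fields."""

    max_new_tokens: int = 64
    streaming_chunks: int = 16
    sampling_params: Optional[Dict[str, Any]] = None
    _loop_rows: Optional[int] = field(default=None, repr=False)

    def __post_init__(self) -> None:
        # Assumed: enough chunks to cover max_new_tokens (ceiling division).
        self._loop_rows = -(-self.max_new_tokens // self.streaming_chunks)
        if self.sampling_params is None:
            # Stand-in for SamplingParams(max_tokens=max_new_tokens).
            self.sampling_params = {"max_tokens": self.max_new_tokens}


cfg = InferenceConfigSketch()
print(cfg._loop_rows, cfg.sampling_params)  # 4 {'max_tokens': 64}
print(InferenceConfigSketch(max_new_tokens=100)._loop_rows)  # 7
```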
- bos_token_id: Optional[int] = None#
- eos_token_id: Optional[Union[int, List[int]]] = None#
- classmethod from_dict(data)#
Create an instance from a dictionary (deserialization).
- classmethod from_json(json_str)#
Create an instance from a JSON string.
- get_partition_rules(runtime_config: Optional[vInferencePreCompileConfig] = None) Tuple[Tuple[str, Any], ...][source]#
Generates or retrieves the sharding partition rules for the vInference engine.
If self.partition_rules is already set (custom rules provided), it returns them directly.
Otherwise, it constructs a default set of partition rules based on the axis names defined in self.partition_axis. These default rules aim to provide sensible sharding for common model components: input sequences (sequences, running_token) are sharded along the batch and sequence axes; attention masks and position IDs are sharded similarly; and past key-value states (the attention cache), including common quantized formats (8-bit, NF4), are sharded across the batch, key-sequence, head, and attention-dimension axes.
Any parameters/states not matching the specific rules are replicated by default (.*).
- Parameters
runtime_config – An optional vInferencePreCompileConfig. Currently unused in the default rule generation but available for potential customization in subclasses or future versions.
- Returns
A tuple of partition rules. Each rule is a tuple containing a regex pattern (string) matching parameter or state names, and a jax.sharding.PartitionSpec defining how the matched items should be sharded.
- Return type
Tuple[Tuple[str, Any], …]
- Raises
AssertionError – If self.partition_rules is None and self.partition_axis is also None, as axis names are required to generate default rules.
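Partition rules are (regex, spec) pairs. The sketch below shows one common way such rules are applied, assuming first-match-wins semantics with re.search and a trailing ".*" catch-all for replication. Plain tuples stand in for jax.sharding.PartitionSpec so the example runs without JAX, and the rule set and parameter names are hypothetical.

```python
import re
from typing import Any, Sequence, Tuple

# Hypothetical rules; tuples stand in for jax.sharding.PartitionSpec.
RULES: Tuple[Tuple[str, Any], ...] = (
    (r".*kernel.*", ("fsdp", None)),
    (r"sequences", ("batch", "sequence")),
    (r".*", ()),  # catch-all: an empty spec means "replicate"
)


def match_rule(name: str, rules: Sequence[Tuple[str, Any]]) -> Any:
    """Return the spec of the first rule whose pattern is found in `name`."""
    for pattern, spec in rules:
        if re.search(pattern, name):
            return spec
    raise ValueError(f"no partition rule matches {name!r}")


print(match_rule("layers/0/attn/q_proj/kernel", RULES))  # ('fsdp', None)
print(match_rule("sequences", RULES))  # ('batch', 'sequence')
print(match_rule("layers/0/norm/scale", RULES))  # ()
```

Under first-match semantics, ordering matters: the catch-all must come last, or it would shadow every more specific rule.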
- max_new_tokens: int = 64#
- num_return_sequences: Optional[Union[int, Dict[int, int]]] = 1#
- pad_token_id: Optional[int] = None#
- partition_axis: Optional[PartitionAxis] = None#
- partition_rules: Optional[Tuple[Tuple[str, Any]]] = None#
- replace(**kwargs)#
- sampling_params: Optional[SamplingParams] = None#
- streaming_chunks: int = 16#
- to_dict()#
Convert the instance to a dictionary for JSON serialization.
- to_json(**kwargs)#
Convert the instance to a JSON string.
- class easydel.__init__.vInferencePreCompileConfig(batch_size: Union[int, List[int]] = 1, prefill_length: Optional[Union[int, List[int]]] = None, vision_included: Union[bool, List[bool]] = False, vision_batch_size: Optional[Union[int, List[int]]] = None, vision_channels: Optional[Union[int, List[int]]] = None, vision_height: Optional[Union[int, List[int]]] = None, vision_width: Optional[Union[int, List[int]]] = None, required_props: Optional[Union[Mapping[str, Dict[str, Any]], List[Mapping[str, Dict[str, Any]]]]] = None)[source]#
Bases: object
Configuration class for pre-compiling vInference functions.
This class holds parameters that define the shape and properties of inputs expected by the vInference engine during pre-compilation. It allows specifying different configurations, potentially in lists, to compile for multiple scenarios.
- batch_size#
Batch size or list of batch sizes for text generation.
- Type
Union[int, List[int]]
- prefill_length#
Prefill sequence length or list of lengths. If None, it might be inferred or not used depending on the context.
- Type
Optional[Union[int, List[int]]]
- vision_included#
Whether vision inputs are included in the model.
- Type
Union[bool, List[bool]]
- vision_batch_size#
Batch size for vision inputs. Only relevant if vision_included is True.
- Type
Optional[Union[int, List[int]]]
- vision_channels#
Number of channels for vision inputs. Only relevant if vision_included is True.
- Type
Optional[Union[int, List[int]]]
- vision_height#
Height of vision inputs. Only relevant if vision_included is True.
- Type
Optional[Union[int, List[int]]]
- vision_width#
Width of vision inputs. Only relevant if vision_included is True.
- Type
Optional[Union[int, List[int]]]
- required_props#
Optional dictionary or list of dictionaries specifying required properties for advanced configuration (e.g., specific model arguments).
- Type
Optional[Union[Mapping[str, Dict[str, Any]], List[Mapping[str, Dict[str, Any]]]]]
- batch_size: Union[int, List[int]] = 1#
- extract() dict[source]#
Converts the configuration instance into a dictionary.
This method is useful for serialization or easily accessing all configuration values.
- Returns
A dictionary representation of the vInferencePreCompileConfig instance.
- classmethod from_dict(data)#
Create an instance from a dictionary (deserialization).
- classmethod from_json(json_str)#
Create an instance from a JSON string.
- get_default_hash() int[source]#
Generates a unique integer hash representing the configuration.
This hash is calculated based on the string representation of all configuration attributes, ensuring that identical configurations produce the same hash. This is crucial for caching compiled functions based on their configuration.
- Returns
An integer hash value representing the configuration.
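Hashing the string form of every attribute means two configurations with equal field values map to the same cache key. The sketch below illustrates that idea with an abbreviated, hypothetical dataclass. It uses hashlib rather than Python's built-in hash(), because str hashes are randomized per process unless PYTHONHASHSEED is fixed; which hash function EasyDeL actually uses is not stated here.

```python
import hashlib
from dataclasses import dataclass, fields
from typing import List, Optional, Union


@dataclass
class PreCompileSketch:
    """Abbreviated, hypothetical stand-in for vInferencePreCompileConfig."""

    batch_size: Union[int, List[int]] = 1
    prefill_length: Optional[Union[int, List[int]]] = None

    def get_default_hash(self) -> int:
        # Build a canonical string from every field, then hash it stably.
        text = ";".join(f"{f.name}={getattr(self, f.name)!r}" for f in fields(self))
        digest = hashlib.sha256(text.encode("utf-8")).hexdigest()
        return int(digest[:16], 16)  # truncate to a 64-bit integer key


a = PreCompileSketch(batch_size=2, prefill_length=512)
b = PreCompileSketch(batch_size=2, prefill_length=512)
c = PreCompileSketch(batch_size=4, prefill_length=512)
print(a.get_default_hash() == b.get_default_hash())  # True
print(a.get_default_hash() == c.get_default_hash())  # False
```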
- get_standalones() List[vInferencePreCompileConfig][source]#
Generates a list of standalone configurations from a potentially multi-value config.
If any attribute in the current configuration is a list (indicating multiple scenarios), this method expands the configuration into multiple individual vInferencePreCompileConfig instances. Each resulting instance represents a single, specific compilation scenario.
If an attribute’s list is shorter than the longest list among all attributes, its last element is repeated to ensure all generated configurations have values for all attributes.
If the original configuration is already standalone (no list attributes), this method returns a list containing only the original instance.
- Returns
A list of vInferencePreCompileConfig instances, each representing a single, standalone compilation scenario.
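The expansion rule above (scalars are shared by every scenario, shorter lists repeat their last element) can be sketched on a plain dict of fields. expand_standalones is a hypothetical helper that works on dicts rather than vInferencePreCompileConfig instances.

```python
from typing import Any, Dict, List


def expand_standalones(config: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Expand list-valued fields into one single-valued config per scenario."""
    list_lengths = [len(v) for v in config.values() if isinstance(v, list)]
    if not list_lengths:
        return [config]  # already standalone
    count = max(list_lengths)
    standalones = []
    for i in range(count):
        entry = {}
        for key, value in config.items():
            if isinstance(value, list):
                # Shorter lists repeat their last element.
                entry[key] = value[min(i, len(value) - 1)]
            else:
                entry[key] = value  # scalars are shared by every scenario
        standalones.append(entry)
    return standalones


multi = {"batch_size": [1, 2, 4], "prefill_length": [128, 256], "vision_included": False}
for cfg in expand_standalones(multi):
    print(cfg)
```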
- prefill_length: Optional[Union[int, List[int]]] = None#
- replace(**kwargs)#
- required_props: Optional[Union[Mapping[str, Dict[str, Any]], List[Mapping[str, Dict[str, Any]]]]] = None#
- to_dict()#
Convert the instance to a dictionary for JSON serialization.
- to_json(**kwargs)#
Convert the instance to a JSON string.
- vision_batch_size: Optional[Union[int, List[int]]] = None#
- vision_channels: Optional[Union[int, List[int]]] = None#
- vision_height: Optional[Union[int, List[int]]] = None#
- vision_included: Union[bool, List[bool]] = False#
- vision_width: Optional[Union[int, List[int]]] = None#
- class easydel.__init__.vWhisperInference(model: Any, tokenizer: Any, processor: Any, inference_config: Optional[vWhisperInferenceConfig] = None, dtype: Union[str, type[Any], numpy.dtype, SupportsDType] = jax.numpy.float32)[source]#
Bases: object
Whisper inference pipeline for performing speech-to-text transcription or translation.
- Parameters
model (WhisperForConditionalGeneration) – The fine-tuned Whisper model to use for inference.
tokenizer (WhisperTokenizer) – Tokenizer for Whisper.
processor (WhisperProcessor) – Processor for Whisper.
inference_config (vWhisperInferenceConfig, optional) – Inference configuration.
dtype (jax.typing.DTypeLike, optional, defaults to jnp.float32) – Data type for computations.
- generate(audio_input: Union[str, bytes, ndarray, Dict[str, Union[ndarray, int]]], chunk_length_s: float = 30.0, stride_length_s: Optional[Union[float, list[float]]] = None, batch_size: Optional[int] = None, language: Optional[str] = None, task: Optional[str] = None, return_timestamps: Optional[bool] = None)[source]#
Transcribe or translate audio input.
- Parameters
audio_input (tp.Union[str, bytes, np.ndarray, tp.Dict[str, tp.Union[np.ndarray, int]]]) – Input audio. Can be a local file path, URL, bytes, numpy array, or a dictionary containing the array and sampling rate.
chunk_length_s (float, optional, defaults to 30.0) – Length of audio chunks in seconds.
stride_length_s (float or list[float], optional) – Stride length for chunking audio, in seconds. Defaults to chunk_length_s / 6.
batch_size (int, optional) – Batch size for processing. Defaults to the batch_size in inference_config.
language (str, optional) – Language of the input audio. Defaults to the language in inference_config.
task (str, optional) – Task to perform (e.g., “transcribe”, “translate”). Defaults to the task in inference_config.
return_timestamps (bool, optional) – Whether to return timestamps with the transcription. Defaults to the return_timestamps in inference_config.
- Returns
A dictionary containing the transcribed text (“text”) and optionally other information like timestamps or detected language.
- Return type
dict
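With the defaults, a 30 s chunk gets a stride of 30 / 6 = 5 s. The sketch below converts those durations to sample counts and lists overlapping windows. It assumes a 16 kHz sampling rate (Whisper's standard input rate) and a Hugging Face-style scheme in which the stride is trimmed from each side of a chunk, so consecutive windows advance by chunk_length - 2 × stride. chunk_windows is a hypothetical helper that illustrates the arithmetic only; EasyDeL's exact chunk boundaries may differ.

```python
from typing import List, Optional, Tuple


def chunk_windows(
    num_samples: int,
    chunk_length_s: float = 30.0,
    stride_length_s: Optional[float] = None,
    sampling_rate: int = 16_000,
) -> List[Tuple[int, int]]:
    """Return (start, end) sample indices of overlapping audio chunks."""
    if stride_length_s is None:
        stride_length_s = chunk_length_s / 6  # documented default
    chunk_len = int(round(chunk_length_s * sampling_rate))
    stride = int(round(stride_length_s * sampling_rate))
    step = chunk_len - 2 * stride  # stride is trimmed on both sides
    if step <= 0:
        raise ValueError("chunk length must exceed twice the stride")
    windows = []
    for start in range(0, num_samples, step):
        end = min(start + chunk_len, num_samples)
        windows.append((start, end))
        if end == num_samples:
            break  # this window already reaches the end of the audio
    return windows


# 70 s of 16 kHz audio with the defaults: 30 s chunks, 5 s stride, 20 s step.
print(chunk_windows(70 * 16_000))
# [(0, 480000), (320000, 800000), (640000, 1120000)]
```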
- class easydel.__init__.vWhisperInferenceConfig(batch_size: Optional[int] = 1, max_length: Optional[int] = None, generation_config: Optional[Any] = None)[source]#
Bases: object
Configuration class for Whisper inference.
- Parameters
batch_size (int, optional, defaults to 1) – Batch size used for inference.
max_length (int, optional) – Maximum sequence length for generation.
generation_config (transformers.GenerationConfig, optional) – Generation configuration object.
logits_processor (optional) – Not used.
return_timestamps (bool, optional) – Whether to return timestamps with the transcribed text.
task (str, optional) – Task for the model (e.g., “transcribe”, “translate”).
language (str, optional) – Language of the input audio.
is_multilingual (bool, optional) – Whether the model is multilingual.
- batch_size: Optional[int] = 1#
- classmethod from_dict(data)#
Create an instance from a dictionary (deserialization).
- classmethod from_json(json_str)#
Create an instance from a JSON string.
- generation_config: Optional[Any] = None#
- is_multilingual = None#
- language = None#
- logits_processor = None#
- max_length: Optional[int] = None#
- replace(**kwargs)#
- return_timestamps = None#
- task = None#
- to_dict()#
Convert the instance to a dictionary for JSON serialization.
- to_json(**kwargs)#
Convert the instance to a JSON string.