easydel.__init__


class easydel.__init__.ArcticConfig(vocab_size=32000, hidden_size=4096, intermediate_size=14336, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=None, hidden_act='silu', max_position_embeddings=4096, initializer_range=0.02, rms_norm_eps=1e-05, use_cache=True, pad_token_id=None, bos_token_id=1, eos_token_id=2, tie_word_embeddings=False, rope_theta=1000000.0, sliding_window=None, attention_dropout=0.0, num_experts_per_tok=1, num_local_experts=8, router_aux_loss_coef=0.001, moe_layer_frequency=2, parallel_attn_mlp_res=False, moe_train_capacity_factor=1, moe_eval_capacity_factor=1, enable_expert_tensor_parallelism=False, moe_min_capacity=0, moe_token_dropping=True, quantization=None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, rope_scaling: Dict[str, Union[str, float]] = None, **kwargs)

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 32000) – Vocabulary size of the ARCTIC model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4096) – Dimensionality of the hidden representations in the decoder layers.

  • intermediate_size (int, optional, defaults to 14336) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer decoder.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer decoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer decoder.

  • num_key_value_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer decoder. Defaults to num_attention_heads if not set; a smaller value gives grouped-query attention.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) used in the decoder. If a string, values such as “silu” (the default), “gelu”, “relu”, “swish”, and “gelu_new” are accepted.

  • max_position_embeddings (int, optional, defaults to 4096) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the RMS normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether the model should return the last key/value states (the KV cache) so they can be reused in later decoding steps. Not used by all models; only relevant if config.is_decoder=True.

  • pad_token_id (int, optional) – The index of the padding token in the vocabulary. Defaults to None (no padding token).

  • bos_token_id (int, optional, defaults to 1) – The index of the beginning-of-sequence token in the vocabulary.

  • eos_token_id (int, optional, defaults to 2) – The index of the end-of-sequence token in the vocabulary.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 1e6) – The theta value to use for rotary position embeddings.

  • sliding_window (int, optional) – The sliding window size to use for attention. If not specified, no sliding window attention is used.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • num_experts_per_tok (int, optional, defaults to 1) – The number of experts each token is routed to in the mixture-of-experts (MoE) layers.

  • num_local_experts (int, optional, defaults to 8) – The number of experts in each MoE layer.

  • router_aux_loss_coef (float, optional, defaults to 0.001) – The coefficient of the router's auxiliary (load-balancing) loss.

  • moe_layer_frequency (int, optional, defaults to 2) – How often an MoE block replaces the dense MLP in the layer stack; with the default of 2, every second layer is an MoE layer.

  • parallel_attn_mlp_res (bool, optional, defaults to False) – Whether to parallelize attention and MLP residual connections.

  • moe_train_capacity_factor (float, optional, defaults to 1) – The capacity factor for MoE layers during training.

  • moe_eval_capacity_factor (float, optional, defaults to 1) – The capacity factor for MoE layers during evaluation.

  • enable_expert_tensor_parallelism (bool, optional, defaults to False) – Whether to enable expert tensor parallelism.

  • moe_min_capacity (int, optional, defaults to 0) – The minimum capacity for MoE layers.

  • moe_token_dropping (bool, optional, defaults to True) – Whether to drop tokens that are routed to an expert that has already reached its capacity.

  • quantization (str, optional) – The quantization configuration.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing (rematerialization) policy.

  • use_scan_mlp (bool, optional, defaults to False) – Whether to compute the MLP in chunks with a scan loop.

  • scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size used when use_scan_mlp is enabled.

  • bits (int, optional) – The number of bits.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The rope scaling configuration.
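The capacity settings above bound how many tokens each expert may process. The sketch below assumes the GShard/DeepSpeed-style rule capacity = max(ceil(num_tokens × experts_per_token / num_experts × capacity_factor), moe_min_capacity), and shows what moe_token_dropping=True means under that rule: tokens routed to a full expert are dropped. The formula and the helper names are assumptions for illustration, not EasyDeL's code.

```python
import math
from typing import List


def expert_capacity(num_tokens: int, num_experts: int, capacity_factor: float = 1.0,
                    min_capacity: int = 0, experts_per_token: int = 1) -> int:
    """Slots per expert under an assumed GShard/DeepSpeed-style capacity rule."""
    cap = math.ceil(num_tokens * experts_per_token / num_experts * capacity_factor)
    return max(cap, min_capacity)


def route_with_dropping(expert_ids: List[int], num_experts: int, capacity: int) -> List[bool]:
    """Keep tokens in arrival order until their expert is full (token dropping)."""
    load = [0] * num_experts
    kept = []
    for expert in expert_ids:
        keep = load[expert] < capacity
        if keep:
            load[expert] += 1
        kept.append(keep)
    return kept


# 4096 tokens, 8 experts, top-1 routing, capacity factor 1 -> 512 slots per expert.
print(expert_capacity(4096, num_experts=8))
# Skewed routing: expert 0 receives 3 tokens but has only 2 slots, so the last one is dropped.
print(route_with_dropping([0, 0, 1, 0], num_experts=2, capacity=2))  # [True, True, True, False]
```

Raising moe_train_capacity_factor above 1 leaves headroom for uneven routing at the cost of padded, unused slots.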

get_partition_rules(*args, **kwargs)

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]
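The return type is a tuple of (str, PartitionSpec) pairs. A common way such rules are applied in JAX codebases, assumed here rather than taken from EasyDeL's source, is to treat each string as a regular expression, join each parameter's path with "/", and use the first rule that matches. The rule patterns and parameter names below are hypothetical, and plain tuples stand in for jax.sharding.PartitionSpec so the sketch runs without JAX.

```python
import re
from typing import Any, Dict, Optional, Tuple

# Hypothetical rules: (regex over the parameter path, partition spec).
# Plain tuples stand in for jax.sharding.PartitionSpec.
RULES: Tuple[Tuple[str, Optional[tuple]], ...] = (
    (r"embed_tokens/embedding", ("tp", ("fsdp", "sp"))),
    (r"self_attn/(q_proj|k_proj|v_proj)/kernel", (("fsdp", "sp"), "tp")),
    (r".*", None),  # fallback: replicate everything else
)


def flatten(tree: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Flatten a nested dict into {"a/b/c": leaf}."""
    out: Dict[str, Any] = {}
    for key, value in tree.items():
        path = f"{prefix}/{key}" if prefix else key
        if isinstance(value, dict):
            out.update(flatten(value, path))
        else:
            out[path] = value
    return out


def match_rules(params: Dict[str, Any], rules=RULES) -> Dict[str, Optional[tuple]]:
    """Assign each parameter the spec of the first rule whose regex matches its path."""
    specs = {}
    for path in flatten(params):
        for pattern, spec in rules:
            if re.search(pattern, path):
                specs[path] = spec
                break
        else:
            raise ValueError(f"no partition rule matches {path!r}")
    return specs


params = {
    "model": {
        "embed_tokens": {"embedding": "E"},
        "layers": {"0": {"self_attn": {"q_proj": {"kernel": "Q"}}}},
        "norm": {"scale": "N"},
    }
}
print(match_rules(params))
```

Because matching stops at the first hit, specific patterns must come before the catch-all ".*" rule.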

static get_weight_decay_exclusions()
property granted_freq_max_position_embedding: int
property granted_mask_max_position_embedding: int
model_type: str = 'arctic'
static rng_keys()
class easydel.__init__.ArcticForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.__init__.ArcticModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.__init__.AttentionMechanisms(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: str, Enum

AUTO = 'auto'
BLOCKWISE = 'blockwise'
CUDA_FLASH_ATTN2 = 'cuda_flash_attn2'
CUDNN = 'cudnn'
FLASH_ATTN2 = 'flash_attn2'
RING = 'ring'
SDPA = 'sdpa'
SPLASH = 'splash'
VANILLA = 'vanilla'
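Because AttentionMechanisms derives from both str and Enum, its members compare equal to their string values and can be looked up from plain strings, such as a value read from a config file. The sketch below re-declares a subset of the members locally (so it runs without EasyDeL installed) to show that behavior; it is a stand-in, not the library class.

```python
from enum import Enum


class AttentionMechanisms(str, Enum):
    """Local stand-in mirroring a subset of easydel's enum values."""
    AUTO = "auto"
    FLASH_ATTN2 = "flash_attn2"
    SDPA = "sdpa"
    VANILLA = "vanilla"


# Look a member up from a plain string (e.g. read from a YAML/JSON config).
mech = AttentionMechanisms("flash_attn2")
assert mech is AttentionMechanisms.FLASH_ATTN2

# str-mixin members compare equal to their raw values.
assert AttentionMechanisms.SDPA == "sdpa"

# An unknown value raises ValueError, so typos fail early.
try:
    AttentionMechanisms("flash_attn3")
except ValueError:
    print("unknown attention mechanism")
```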
class easydel.__init__.AttentionMetadata(runtime_dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType], runtime_softmax_dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, NoneType] = None, sequence_axis_name: str = Ellipsis, mesh: Optional[jax._src.mesh.Mesh] = Ellipsis, platform: easydel.infra.etils.EasyDeLPlatforms = Ellipsis, backend: easydel.infra.etils.EasyDeLBackends = Ellipsis, partition_axis: eformer.escale.partition.constraints.PartitionAxis = Ellipsis, base_config: Optional[easydel.infra.base_config.EasyDeLBaseConfig] = None, scan_ring_attention: bool = Ellipsis, softmax_scale: float = Ellipsis, dropout_prob: float = Ellipsis, blocksize_q: int = Ellipsis, blocksize_k: int = Ellipsis, blocksize_b: int = Ellipsis)

Bases: object

backend: EasyDeLBackends = Ellipsis
base_config: Optional[EasyDeLBaseConfig] = None
blocksize_b: int = Ellipsis
blocksize_k: int = Ellipsis
blocksize_q: int = Ellipsis
dropout_prob: float = Ellipsis
classmethod from_config(config: EasyDeLBaseConfig, softmax_scale: float, dropout_prob: float = 0.0)
get_partition_specs(mode: RuntimeType, BTHD: bool = True)
mesh: Optional[Mesh] = Ellipsis
partition_axis: PartitionAxis = Ellipsis
platform: EasyDeLPlatforms = Ellipsis
replace(**kwargs)
runtime_dtype: Union[str, type[Any], dtype, SupportsDType]
runtime_softmax_dtype: Optional[Union[str, type[Any], dtype, SupportsDType]] = None
scan_ring_attention: bool = Ellipsis
sequence_axis_name: str = Ellipsis
set_attrs_carefully(attr_name: str, default: Optional[Any], pickup_name: Optional[str] = None, use_base_config: bool = True)
softmax_scale: float = Ellipsis
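Many AttentionMetadata fields default to Ellipsis, which reads like a "not provided, resolve later" sentinel. The sketch below is a hypothetical illustration of that pattern in the spirit of set_attrs_carefully: it assumes a field left as Ellipsis is filled from base_config (looking up pickup_name if given, else attr_name) and otherwise from the supplied default. FakeBaseConfig, MetadataSketch, and the attn_dropout field are invented stand-ins, not EasyDeL's implementation.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FakeBaseConfig:
    """Hypothetical stand-in for EasyDeLBaseConfig."""
    blocksize_q: int = 512
    attn_dropout: float = 0.1


@dataclass
class MetadataSketch:
    """Hypothetical stand-in for AttentionMetadata with Ellipsis sentinels."""
    blocksize_q: Any = Ellipsis
    dropout_prob: Any = Ellipsis
    base_config: Optional[FakeBaseConfig] = None

    def set_attrs_carefully(self, attr_name: str, default: Any,
                            pickup_name: Optional[str] = None,
                            use_base_config: bool = True) -> None:
        # Assumed semantics: only fill attributes still holding the sentinel.
        if getattr(self, attr_name) is not Ellipsis:
            return
        source_name = pickup_name or attr_name
        if use_base_config and self.base_config is not None:
            value = getattr(self.base_config, source_name, default)
        else:
            value = default
        setattr(self, attr_name, value)


meta = MetadataSketch(base_config=FakeBaseConfig())
meta.set_attrs_carefully("blocksize_q", 128)
meta.set_attrs_carefully("dropout_prob", 0.0, pickup_name="attn_dropout")
print(meta.blocksize_q, meta.dropout_prob)  # 512 0.1
```

Using Ellipsis rather than None as the sentinel keeps None available as a legitimate explicit value.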
class easydel.__init__.AttentionRegistry

Bases: object

Registry for attention implementations.

classmethod create(impl_name: str, metadata: AttentionMetadata) AttentionImpl

Create an instance of an attention implementation by name.

classmethod get(impl_name: str) Type[AttentionImpl]

Get an attention implementation by name.

classmethod list_implementations() List[str]

List all registered attention implementations.

classmethod register(impl_cls: Type[AttentionImpl]) Type[AttentionImpl]

Decorator to register an attention implementation.

Example usage:

>>> @AttentionRegistry.register
... class CustomAttention(AttentionImpl):
...     ...  # implement the methods required by AttentionImpl
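The four classmethods above form a standard name-keyed registry. The self-contained sketch below shows that pattern with a decorator that stores a class and returns it unchanged. The Registry and AttentionImplBase classes, and the name attribute used as the registry key, are hypothetical; EasyDeL may derive the key differently.

```python
from typing import Dict, List, Type


class AttentionImplBase:
    """Hypothetical base class; `name` is an assumed registry key."""
    name: str = ""

    def __init__(self, metadata: dict):
        self.metadata = metadata


class Registry:
    _impls: Dict[str, Type[AttentionImplBase]] = {}

    @classmethod
    def register(cls, impl_cls):
        # Used as a decorator: store the class under its name and return it unchanged.
        if not impl_cls.name:
            raise ValueError("implementation must define a non-empty name")
        cls._impls[impl_cls.name] = impl_cls
        return impl_cls

    @classmethod
    def get(cls, impl_name: str):
        try:
            return cls._impls[impl_name]
        except KeyError:
            raise ValueError(f"unknown attention implementation {impl_name!r}") from None

    @classmethod
    def create(cls, impl_name: str, metadata: dict):
        # Look the class up by name, then instantiate it with the metadata.
        return cls.get(impl_name)(metadata)

    @classmethod
    def list_implementations(cls) -> List[str]:
        return sorted(cls._impls)


@Registry.register
class VanillaAttention(AttentionImplBase):
    name = "vanilla"


impl = Registry.create("vanilla", {"softmax_scale": 0.125})
print(type(impl).__name__, Registry.list_implementations())
```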

class easydel.__init__.AutoEasyDeLConfig

Bases: object

classmethod from_pretrained(pretrained_model_name_or_path: str, sharding_axis_dims: Sequence[int] = (1, -1, 1, 1), sharding_dcn_axis_dims: Optional[Sequence[int]] = None, sharding_axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), partition_axis: Optional[PartitionAxis] = None, shard_attention_computation: bool = True, backend: Optional[EasyDeLBackends] = None, platform: Optional[EasyDeLPlatforms] = None, model_task: TaskType = TaskType.CAUSAL_LM, from_torch: bool = False, **kwargs) EasyDeLBaseConfig

Loads the configuration of a pretrained model from the Hugging Face Hub (or a local path) and returns it as an EasyDeL configuration with the requested sharding settings applied. It takes the model identifier (e.g., ‘bert-base-uncased’) and returns an instance of the configuration class that corresponds to the model; no model weights are loaded.

Parameters
  • cls – The class on which the method is called.

  • pretrained_model_name_or_path – str: The model identifier on the Hugging Face Hub, or a local path.

  • sharding_axis_dims – tp.Sequence[int]: The size of each axis of the device mesh used to shard arrays in EasyDeL; -1 means the axis takes all remaining devices.

  • shard_attention_computation – bool: Whether to use shard_map for the attention computation.

  • backend – tp.Optional[EasyDeLBackends]: The backend to use for the model.

  • model_task – TaskType: The task type of the model to load.

  • from_torch – Whether the configuration should be loaded from a PyTorch model.

  • **kwargs – Additional arguments passed to the model and config classes.

Returns

The model configuration (an EasyDeLBaseConfig).
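With the default sharding_axis_dims=(1, -1, 1, 1) and axis names ("dp", "fsdp", "tp", "sp"), the -1 entry absorbs whatever devices the other axes leave over. The sketch below assumes -1 is resolved the way numpy.reshape resolves it, and uses a plain integer array as a stand-in for jax.devices() so it runs without accelerators; resolve_mesh_shape is a hypothetical helper.

```python
import numpy as np


def resolve_mesh_shape(axis_dims, num_devices):
    """Resolve a single -1 entry the way numpy.reshape would (hypothetical helper)."""
    # Reshaping a flat array of device ids infers the -1 dimension and
    # raises ValueError when the sizes cannot be made to fit.
    return np.arange(num_devices).reshape(axis_dims).shape


axis_names = ("dp", "fsdp", "tp", "sp")
for dims in [(1, -1, 1, 1), (2, -1, 2, 1)]:
    shape = resolve_mesh_shape(dims, num_devices=8)
    print(dict(zip(axis_names, shape)))
```

On 8 devices the default layout becomes a pure FSDP mesh of shape (1, 8, 1, 1), while (2, -1, 2, 1) yields (2, 2, 2, 1).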

class easydel.__init__.AutoEasyDeLModel

Bases: BaseAutoEasyModel

This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL compatible models. It utilizes the EasyDeL library for distributed training and inference with JAX.

This class inherits from the BaseAutoEasyModel class, providing functionalities for model loading, parameter sharding, and interaction with the EasyDeL framework.

model_task: TaskType = 'base-module'
class easydel.__init__.AutoEasyDeLModelForCausalLM

Bases: BaseAutoEasyModel

This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL compatible models. It utilizes the EasyDeL library for distributed training and inference with JAX.

This class inherits from the BaseAutoEasyModel class, providing functionalities for model loading, parameter sharding, and interaction with the EasyDeL framework.

Examples

>>> import jax
>>> from easydel import AutoEasyDeLModelForCausalLM
>>> # Load a GPT-2 model on a single CPU
>>> model = AutoEasyDeLModelForCausalLM.from_pretrained(
...   "gpt2", device=jax.devices("cpu")[0]
... )
>>> # Load a GPT-2 model sharded across 8 GPUs with data parallelism (DP) and fully sharded data parallelism (FSDP)
>>> model = AutoEasyDeLModelForCausalLM.from_pretrained(
...  "gpt2",
...  sharding_axis_dims=(1, 8, 1, 1),
...  sharding_axis_names=("dp", "fsdp", "tp", "sp"),
...  device=jax.devices("cpu")[0],  # offload to CPU [OPTIONAL]
...  from_torch=True,
... )

model_task: TaskType = 'causal-language-model'
class easydel.__init__.AutoEasyDeLModelForImageTextToText

Bases: BaseAutoEasyModel

This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL compatible models. It utilizes the EasyDeL library for distributed training and inference with JAX.

This class inherits from the BaseAutoEasyModel class, providing functionalities for model loading, parameter sharding, and interaction with the EasyDeL framework.

model_task: TaskType = 'image-text-to-text'
class easydel.__init__.AutoEasyDeLModelForSeq2SeqLM

Bases: BaseAutoEasyModel

This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL compatible models. It utilizes the EasyDeL library for distributed training and inference with JAX.

This class inherits from the BaseAutoEasyModel class, providing functionalities for model loading, parameter sharding, and interaction with the EasyDeL framework.

model_task: TaskType = 'sequence-to-sequence'
class easydel.__init__.AutoEasyDeLModelForSequenceClassification

Bases: BaseAutoEasyModel

This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL compatible models. It utilizes the EasyDeL library for distributed training and inference with JAX.

This class inherits from the BaseAutoEasyModel class, providing functionalities for model loading, parameter sharding, and interaction with the EasyDeL framework.

model_task: TaskType = 'sequence-classification'
class easydel.__init__.AutoEasyDeLModelForSpeechSeq2Seq

Bases: BaseAutoEasyModel

This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL compatible models. It utilizes the EasyDeL library for distributed training and inference with JAX.

This class inherits from the BaseAutoEasyModel class, providing functionalities for model loading, parameter sharding, and interaction with the EasyDeL framework.

Examples

>>> import jax
>>> from easydel import AutoEasyDeLModelForSpeechSeq2Seq
>>> # Load openai/whisper-large-v3-turbo with automatic sharding
>>> model = AutoEasyDeLModelForSpeechSeq2Seq.from_pretrained(
...  "openai/whisper-large-v3-turbo",
...  auto_shard_model=True,
... )
>>> # Load openai/whisper-large-v3-turbo sharded across 8 GPUs with data parallelism (DP) and fully sharded data parallelism (FSDP)
>>> model = AutoEasyDeLModelForSpeechSeq2Seq.from_pretrained(
...  "openai/whisper-large-v3-turbo",
...  sharding_axis_dims=(1, 8, 1, 1),
...  sharding_axis_names=("dp", "fsdp", "tp", "sp"),
...  device=jax.devices("cpu")[0],  # offload to CPU [OPTIONAL]
...  from_torch=True,
... )

model_task: TaskType = 'speech-sequence-to-sequence'
class easydel.__init__.AutoEasyDeLModelForZeroShotImageClassification

Bases: BaseAutoEasyModel

This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL compatible models. It utilizes the EasyDeL library for distributed training and inference with JAX.

This class inherits from the BaseAutoEasyModel class, providing functionalities for model loading, parameter sharding, and interaction with the EasyDeL framework.

model_task: TaskType = 'zero-shot-image-classification'
class easydel.__init__.AutoEasyDeLVisionModel

Bases: BaseAutoEasyModel

This class provides a convenient way to load and shard pretrained models from the Hugging Face Hub and convert them into EasyDeL compatible models. It utilizes the EasyDeL library for distributed training and inference with JAX.

This class inherits from the BaseAutoEasyModel class, providing functionalities for model loading, parameter sharding, and interaction with the EasyDeL framework.

model_task: TaskType = 'vision-module'
class easydel.__init__.AutoShardAndGatherFunctions

Bases: object

A class to automatically generate shard and gather functions for a given model configuration.

This class provides two methods to generate shard and gather functions:

  • from_config: Generates functions based on a provided EasyDeLBaseConfig object.

  • from_pretrained: Generates functions based on a pretrained model name or path.


classmethod from_config(config: EasyDeLBaseConfig, partition_rules: Optional[Tuple[Tuple[str, PartitionSpec]]] = None, flatten: bool = True, model_task: TaskType = TaskType.CAUSAL_LM, depth_target: Optional[List[str]] = None)

Generates shard and gather functions based on a provided EasyDeLBaseConfig object.

Parameters
  • config – An EasyDeLBaseConfig object containing the model configuration.

  • partition_rules – A tuple of tuples containing partition rule names and PartitionSpec objects. If None, uses the default partition rules from the config.

  • flatten – Whether to flatten the shard and gather functions. Defaults to True.

  • model_task (TaskType) – The task type of the model to load.

  • depth_target – Pads the sharding tree to the given depth; for example, depth_target = [“row”] turns {params:tensor} into {row:{params:tensor}}. Defaults to None.

Returns

A tuple containing the shard and gather functions.
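Two ideas from from_config can be illustrated without devices: depth_target wraps the tree under extra keys, and the shard and gather functions are inverses applied per parameter. In the sketch below, pad_to_depth follows the example in the depth_target description, while make_shard_and_gather is a hypothetical stand-in that splits an array into per-device chunks with numpy and concatenates them back; real functions would place data on devices using JAX shardings.

```python
import numpy as np


def pad_to_depth(tree, depth_target):
    """Wrap `tree` under the given keys: ["row"] turns {params: x} into {row: {params: x}}."""
    for key in reversed(depth_target or []):
        tree = {key: tree}
    return tree


def make_shard_and_gather(num_shards, axis=0):
    """Return a hypothetical (shard_fn, gather_fn) pair that invert each other."""
    def shard_fn(x):
        return np.split(x, num_shards, axis=axis)  # one chunk per device

    def gather_fn(chunks):
        return np.concatenate(chunks, axis=axis)  # back to one host array

    return shard_fn, gather_fn


print(pad_to_depth({"params": "tensor"}, ["row"]))  # {'row': {'params': 'tensor'}}

shard_fn, gather_fn = make_shard_and_gather(num_shards=4)
weight = np.arange(32, dtype=np.float32).reshape(8, 4)
chunks = shard_fn(weight)
assert [c.shape for c in chunks] == [(2, 4)] * 4
assert np.array_equal(gather_fn(chunks), weight)
```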

static from_params(params, partition_rules, mesh)
classmethod from_pretrained(pretrained_model_name_or_path: str, sharding_axis_dims: Sequence[int] = (1, -1, 1, 1), sharding_dcn_axis_dims: Optional[Sequence[int]] = None, sharding_axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), partition_axis: Optional[PartitionAxis] = None, shard_attention_computation: bool = True, backend: Optional[EasyDeLBackends] = None, platform: Optional[EasyDeLPlatforms] = None, partition_rules: Optional[Tuple[Tuple[str, PartitionSpec]]] = None, flatten: bool = True, config_kwargs: Optional[Mapping[str, Any]] = None, model_task: TaskType = TaskType.CAUSAL_LM, from_torch: bool = False, trust_remote_code: bool = False) Tuple[Mapping[str, Callable], Mapping[str, Callable]]

Generates shard and gather functions based on a pretrained model name or path.

Parameters
  • pretrained_model_name_or_path – The name or path of the pretrained model.

  • sharding_axis_dims – The dimensions of the sharding axes. Defaults to (1, -1, 1, 1).

  • sharding_axis_names – The names of the sharding axes. Defaults to (“dp”, “fsdp”, “tp”, “sp”).

  • partition_axis (PartitionAxis) – The PartitionAxis object that describes how arrays are partitioned in EasyDeL.

  • shard_attention_computation – Whether to shard the attention computation. Defaults to True.

  • backend – The backend to use for custom kernels. Defaults to None.

  • partition_rules – A tuple of tuples containing partition rule names and PartitionSpec objects. If None, uses the default partition rules from the config.

  • flatten – Whether to flatten the shard and gather functions. Defaults to True.

  • config_kwargs – Additional keyword arguments to pass to the AutoEasyDeLConfig constructor. Defaults to None.

  • model_task (TaskType) – The task type of the model to load.

  • from_torch – Whether the configuration should be loaded from a PyTorch model.

  • trust_remote_code (bool) – Whether to trust remote code loaded from the Hugging Face Hub.

Returns

A tuple containing the shard and gather functions.

class easydel.__init__.AutoState

Bases: BaseAutoEasyState

class easydel.__init__.AutoStateForCausalLM

Bases: BaseAutoEasyState

class easydel.__init__.AutoStateForImageSequenceClassification

Bases: BaseAutoEasyState

class easydel.__init__.AutoStateForImageTextToText

Bases: BaseAutoEasyState

class easydel.__init__.AutoStateForSeq2SeqLM

Bases: BaseAutoEasyState

class easydel.__init__.AutoStateForSpeechSeq2Seq

Bases: BaseAutoEasyState

class easydel.__init__.AutoStateForZeroShotImageClassification

Bases: BaseAutoEasyState

class easydel.__init__.AutoStateVisionModel

Bases: BaseAutoEasyState

class easydel.__init__.AyaVisionConfig(vision_config=None, text_config=None, vision_feature_select_strategy='full', vision_feature_layer=-1, downsample_factor=2, adapter_layer_norm_eps=1e-06, image_token_index=255036, **kwargs)

Bases: EasyDeLBaseConfig

This is the configuration class to store the configuration of an [AyaVisionForConditionalGeneration]. It is used to instantiate an AyaVision model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration similar to that of AyaVision, e.g. [CohereForAI/aya-vision-8b](https://huggingface.co/CohereForAI/aya-vision-8b).

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vision_config (Union[AutoConfig, dict], optional, defaults to CLIPVisionConfig) – The config object or dictionary of the vision backbone.

  • text_config (Union[AutoConfig, dict], optional, defaults to LlamaConfig) – The config object or dictionary of the text backbone.

  • vision_feature_select_strategy (str, optional, defaults to “full”) – The feature selection strategy used to select the vision feature from the vision backbone. Can be one of “default” or “full”. If “default”, the CLS token is removed from the vision features. If “full”, the full vision features are used.

  • vision_feature_layer (int, optional, defaults to -1) – The index of the vision backbone layer whose hidden states are used as image features; -1 selects the last layer.

  • downsample_factor (int, optional, defaults to 2) – The downsample factor to apply to the vision features.

  • adapter_layer_norm_eps (float, optional, defaults to 1e-06) – The epsilon value used for layer normalization in the adapter.

  • image_token_index (int, optional, defaults to 255036) – The token index used as the image placeholder in the prompt.
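vision_feature_layer and vision_feature_select_strategy combine as follows: pick one layer's hidden states from the vision backbone, then either keep every token ("full") or drop the leading CLS token ("default"). The numpy sketch below uses random arrays with illustrative shapes (1 CLS token plus 16 patch tokens of width 8, not AyaVision's real sizes); select_image_features is a hypothetical helper.

```python
import numpy as np


def select_image_features(hidden_states, layer, strategy):
    """hidden_states: list of [batch, tokens, dim] arrays, one per backbone layer."""
    feats = hidden_states[layer]  # -1 picks the last layer
    if strategy == "default":
        return feats[:, 1:]  # drop the CLS token at position 0
    if strategy == "full":
        return feats
    raise ValueError(f"unexpected strategy {strategy!r}")


rng = np.random.default_rng(0)
layers = [rng.standard_normal((2, 17, 8)) for _ in range(4)]  # 1 CLS + 16 patches
print(select_image_features(layers, -1, "full").shape)     # (2, 17, 8)
print(select_image_features(layers, -1, "default").shape)  # (2, 16, 8)
```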

get_partition_rules(*args, **kwargs)

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'aya_vision'
sub_configs: Dict[str, 'PretrainedConfig'] = {'text_config': <class 'easydel.__init__.modules.auto.auto_configuration.AutoEasyDeLConfig'>, 'vision_config': <class 'easydel.__init__.modules.auto.auto_configuration.AutoEasyDeLConfig'>}
class easydel.__init__.AyaVisionForConditionalGeneration(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

get_image_features(pixel_values: Union[Array, ndarray, bool, number]) Union[Array, ndarray, bool, number]
loss_type = 'ForCausalLM'
prepare_inputs_for_generation(input_ids: Union[Array, ndarray, bool, number], max_length: int, pixel_values: Optional[Union[Array, ndarray, bool, number]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None)

The prepare_inputs_for_generation function prepares the model inputs for a generation task.

Parameters
  • input_ids – The input token IDs.

  • max_length – The maximum length of the sequence to be generated.

  • pixel_values – tp.Optional[chex.Array]: The image inputs, if any.

  • attention_mask – tp.Optional[chex.Array]: The mask that marks which input tokens may be attended to.

Returns

A dictionary containing the past_key_values, attention_mask, and position ids.
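A common way to prepare generation inputs in Flax/JAX decoders (used, for example, in Hugging Face's Flax models) is to allocate a mask of length max_length, copy the prompt mask into it, and derive position ids from the cumulative sum of the mask so that left padding does not shift positions. The numpy sketch below shows that pattern; it is an assumption that EasyDeL follows it exactly, image inputs and past_key_values are left out, and prepare_generation_inputs is a hypothetical name.

```python
import numpy as np


def prepare_generation_inputs(input_ids, max_length, attention_mask=None):
    batch_size, seq_length = input_ids.shape
    # Mask over the whole generation window; not-yet-generated positions start as 1.
    extended_mask = np.ones((batch_size, max_length), dtype=np.int32)
    if attention_mask is not None:
        # Each real token's position = number of real tokens before it
        # (padding positions get -1 and are masked out anyway).
        position_ids = attention_mask.cumsum(axis=-1) - 1
        extended_mask[:, :seq_length] = attention_mask
    else:
        position_ids = np.broadcast_to(np.arange(seq_length), (batch_size, seq_length))
    return {"attention_mask": extended_mask, "position_ids": position_ids}


ids = np.array([[0, 0, 11, 12], [21, 22, 23, 24]])  # first row is left-padded
mask = np.array([[0, 0, 1, 1], [1, 1, 1, 1]])
out = prepare_generation_inputs(ids, max_length=6, attention_mask=mask)
print(out["position_ids"])
```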

update_inputs_for_generation(model_outputs, model_kwargs)
class easydel.__init__.BaseTrainer(arguments: tp.Optional[TrainingArguments] = None, model_state: tp.Optional[EasyDeLState] = None, model: tp.type[EasyDeLBaseModule] = None, dataset_train: tp.Optional[Dataset] = None, dataset_eval: tp.Optional[Dataset] = None, data_collator: tp.Optional[tp.Callable] = None, finetune: bool = True, checkpoint_path: tp.Optional[tp.Union[str, os.PathLike]] = None, **deprecated_kwargs)

Bases: BaseTrainerProtocol

apply_training_hooks(metrics: LossMetrics) LossMetrics

Apply the training hooks to the loss metrics of a step and return the resulting metrics.

calculate_number_total_flops(params, is_training=True)
compile_aot() bool

Compile the state ahead of time for faster execution.

configure_dataloaders() TrainerConfigureDataloaderOutput

Configures the dataloaders for training and evaluation.

This method creates the training and evaluation dataloaders using the provided datasets and data collator. It also determines the maximum number of training and evaluation steps based on the dataset sizes and training arguments.

Returns

An object containing the configured dataloaders and the maximum number of training and evaluation steps.

Return type

TrainerConfigureDataloaderOutput
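How the maximum step counts are derived depends on TrainingArguments. A common rule, assumed here rather than taken from EasyDeL's code, is steps per epoch = ceil(num_examples / batch_size) multiplied by the number of epochs, with an explicit step limit taking precedence. The function and argument names below are hypothetical.

```python
import math
from typing import Optional


def max_steps(num_examples: int, batch_size: int, num_epochs: int = 1,
              step_limit: Optional[int] = None, drop_last: bool = False) -> int:
    """Hypothetical rule for the number of batches (steps) a dataloader yields."""
    if drop_last:
        per_epoch = num_examples // batch_size  # incomplete final batch is skipped
    else:
        per_epoch = math.ceil(num_examples / batch_size)
    total = per_epoch * num_epochs
    return min(total, step_limit) if step_limit is not None else total


print(max_steps(10_000, batch_size=32, num_epochs=3))  # 313 * 3 = 939
```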

abstract configure_functions() TrainerConfigureFunctionOutput

Configures and JIT-compiles the training and evaluation step functions.

This method sets up the necessary functions for training and evaluation, including:
  • Initialization of the model state.

  • Sharding of the model parameters and optimizer state.

  • JIT-compilation of the training and evaluation step functions.

Returns

An object containing the configured functions and other relevant information.

Return type

TrainerConfigureFunctionOutput

configure_model() TrainerConfigureModelOutput

Configures the model, optimizer, scheduler, and configuration.

This method retrieves the model configuration from the model state, creates the optimizer and scheduler using the training arguments, and returns an object containing the configured model, optimizer, scheduler, and configuration.

Returns

An object containing the configured model, optimizer, scheduler, and configuration.

Return type

TrainerConfigureModelOutput

static count_model_parameters(prm)

Prints the number of model parameters in billions.
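Counting parameters amounts to summing the element counts of every leaf in the parameter tree and dividing by 1e9. A minimal sketch, assuming the parameters are a nested dict of numpy arrays (toy shapes, hypothetical count_params helper):

```python
import numpy as np


def count_params(tree) -> int:
    """Sum the element counts of every array leaf in a nested dict."""
    if isinstance(tree, dict):
        return sum(count_params(v) for v in tree.values())
    return int(np.size(tree))


params = {
    "embed": np.zeros((32_000, 64), dtype=np.float32),
    "layers": {"0": {"kernel": np.zeros((64, 64), dtype=np.float32)}},
}
n = count_params(params)
print(f"{n:,} parameters ({n / 1e9:.4f} B)")  # 2,052,096 parameters (0.0021 B)
```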

abstract create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start']) Callable

Creates a function to collect and process batches of data for training or evaluation.

This function handles padding or truncating sequences to the specified max_sequence_length based on the chosen truncation_mode.

Parameters
  • max_sequence_length (int) – The maximum allowed sequence length.

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode. Defaults to “keep_end”.

Returns

A function that takes a batch of data and returns a processed batch.

Return type

tp.Callable
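The two truncation modes differ only in which end of an over-long sequence survives: "keep_end" keeps the last tokens and "keep_start" keeps the first ones. The sketch below works on plain Python token lists; right-padding with pad_id=0 and the make_collect_fn name are assumptions, since the real collate function returns the trainer's own batch format.

```python
from typing import List, Literal


def make_collect_fn(max_sequence_length: int,
                    truncation_mode: Literal["keep_end", "keep_start"] = "keep_end",
                    pad_id: int = 0):
    """Return a function that truncates or right-pads each sequence to a fixed length."""
    if truncation_mode not in ("keep_end", "keep_start"):
        raise ValueError(f"unknown truncation_mode {truncation_mode!r}")

    def collect(batch: List[List[int]]) -> List[List[int]]:
        out = []
        for seq in batch:
            if len(seq) > max_sequence_length:
                seq = (seq[-max_sequence_length:] if truncation_mode == "keep_end"
                       else seq[:max_sequence_length])
            out.append(seq + [pad_id] * (max_sequence_length - len(seq)))
        return out

    return collect


collect = make_collect_fn(4, "keep_end")
print(collect([[1, 2, 3, 4, 5, 6], [7, 8]]))  # [[3, 4, 5, 6], [7, 8, 0, 0]]
```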

create_progress_bar(total: int, desc: str = '', disabled: bool = False) BaseProgressBar

Create a progress bar of the specified type.

property evaluation_batch_size
static finish()

Finalize the training process.

get_runstage_flops(is_training) Union[float, Tuple[float, bool]]

Return the total number of FLOPs for the model.

initialize_trainer_utils()

Initializes various utilities used by the trainer.

This includes setting up Weights & Biases, initializing the training timer, configuring dataloaders, configuring the model and optimizer, sharding the model and reference model states, and configuring the training and evaluation functions.

property is_process_zero
log_metrics(metrics: Any, pbar: BaseProgressBar, step: int, mode: str = 'train')

Log metrics and update the progress bar.

log_weight_distribution(state: EasyDeLState, step: int)

Log the distribution of weights.

property mesh
property model
on_step_end(state: EasyDeLState, metrics: Any, step: int) Tuple[EasyDeLState, Any]

Hook called at the end of each training step.

on_step_start(state: EasyDeLState, step: int) EasyDeLState

Hook called at the start of each training step.

save_information(output_path: Union[str, Path]) None

Save the generated information to a markdown file.

Parameters

output_path – Path where the markdown file should be saved

save_pretrained(state: EasyDeLState, save_directory: Optional[str] = None, gather_fns: Optional[Union[Any, Mapping[str, Callable], dict[Callable]]] = None, to_torch: bool = False, easystate_to_huggingface_model_kwargs: Optional[dict] = None, torch_save_pretrained_kwargs: Optional[dict] = None)

Saves the model state as a checkpoint file or to a Torch compatible directory.

specs_to_name_sharding(tree, mesh=None)

Convert specs to named sharding.

start_evaluation_hook()

Hook to run before evaluation starts.

start_training_hook()

Hook to run before training starts.

property training_batch_size
class easydel.__init__.CLIPConfig(text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs)

Bases: EasyDeLBaseConfig

[CLIPConfig] is the configuration class to store the configuration of a [CLIPModel]. It is used to instantiate a CLIP model according to the specified arguments, defining the text model and vision model configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the CLIP [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • text_config (dict, optional) – Dictionary of configuration options used to initialize [CLIPTextConfig].

  • vision_config (dict, optional) – Dictionary of configuration options used to initialize [CLIPVisionConfig].

  • projection_dim (int, optional, defaults to 512) – Dimensionality of text and vision projection layers.

  • logit_scale_init_value (float, optional, defaults to 2.6592) – The initial value of the logit_scale parameter. Default is used as per the original CLIP implementation.

  • kwargs (optional) – Dictionary of keyword arguments.

Example:

```python
>>> from transformers import CLIPConfig, CLIPModel

>>> # Initializing a CLIPConfig with openai/clip-vit-base-patch32 style configuration
>>> configuration = CLIPConfig()
>>> # Initializing a CLIPModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
>>> model = CLIPModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
>>> # We can also initialize a CLIPConfig from a CLIPTextConfig and a CLIPVisionConfig
>>> from transformers import CLIPTextConfig, CLIPVisionConfig
>>> # Initializing a CLIPText and CLIPVision configuration
>>> config_text = CLIPTextConfig()
>>> config_vision = CLIPVisionConfig()
>>> config = CLIPConfig.from_text_vision_configs(config_text, config_vision)
```
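The default logit_scale_init_value of 2.6592 is approximately ln(1/0.07), so exp(logit_scale) starts near 14.3, i.e. a softmax temperature of 0.07. The numpy sketch below shows how the scale enters CLIP's image-text similarity logits (cosine similarity multiplied by exp(logit_scale)); clip_logits is a hypothetical helper and the embeddings are random.

```python
import math

import numpy as np

LOGIT_SCALE_INIT = 2.6592  # default logit_scale_init_value


def clip_logits(text_embeds, image_embeds, logit_scale=LOGIT_SCALE_INIT):
    """Return (logits_per_text, logits_per_image) from projected embeddings."""
    # L2-normalize so the dot product is a cosine similarity.
    t = text_embeds / np.linalg.norm(text_embeds, axis=-1, keepdims=True)
    i = image_embeds / np.linalg.norm(image_embeds, axis=-1, keepdims=True)
    logits_per_text = math.exp(logit_scale) * t @ i.T
    return logits_per_text, logits_per_text.T


print(round(math.exp(LOGIT_SCALE_INIT), 2))  # 14.28, roughly 1 / 0.07
rng = np.random.default_rng(0)
text = rng.standard_normal((3, 512))  # projection_dim defaults to 512
image = rng.standard_normal((3, 512))
per_text, per_image = clip_logits(text, image)
print(per_text.shape)  # (3, 3)
```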
classmethod from_text_vision_configs(text_config: CLIPTextConfig, vision_config: CLIPVisionConfig, **kwargs)

Instantiate a [CLIPConfig] (or a derived class) from a CLIP text model configuration and a CLIP vision model configuration.

Returns

An instance of a configuration object

Return type

[CLIPConfig]

get_partition_rules(*arg, **kwargs)

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'clip'
sub_configs: Dict[str, 'PretrainedConfig'] = {'text_config': <class 'easydel.__init__.modules.clip.clip_configuration.CLIPTextConfig'>, 'vision_config': <class 'easydel.__init__.modules.clip.clip_configuration.CLIPVisionConfig'>}
class easydel.__init__.CLIPForImageClassification(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.__init__.CLIPModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

compute_loss(*, labels=None, loss_config=None, loss_kwargs=None, **batch) Tuple[Any, FlaxCLIPOutput]

Compute the loss for a batch and return it together with the model outputs.

get_image_features(pixel_values: Union[Array, ndarray, bool, number])
get_text_features(input_ids: Union[Array, ndarray, bool, number], attention_mask: Optional[Union[Array, ndarray, bool, number]] = None, position_ids: Optional[Union[Array, ndarray, bool, number]] = None)
class easydel.__init__.CLIPTextConfig(vocab_size=49408, hidden_size=512, intermediate_size=2048, projection_dim=512, num_hidden_layers=12, num_attention_heads=8, max_position_embeddings=77, hidden_act='quick_gelu', layer_norm_eps=1e-05, attention_dropout=0.0, initializer_range=0.02, initializer_factor=1.0, pad_token_id=1, bos_token_id=49406, eos_token_id=49407, **kwargs)

Bases: EasyDeLBaseConfig

This is the configuration class to store the configuration of a [CLIPTextModel]. It is used to instantiate a CLIP text encoder according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the text encoder of the CLIP [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 49408) – Vocabulary size of the CLIP text model. Defines the number of different tokens that can be represented by the input_ids passed when calling [CLIPModel].

  • hidden_size (int, optional, defaults to 512) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 2048) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • projection_dim (int, optional, defaults to 512) – Dimensionality of text and vision projection layers.

  • num_hidden_layers (int, optional, defaults to 12) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 8) – Number of attention heads for each attention layer in the Transformer encoder.

  • max_position_embeddings (int, optional, defaults to 77) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048).

  • hidden_act (str or function, optional, defaults to “quick_gelu”) – The non-linear activation function (function or string) in the encoder and pooler. If a string, “gelu”, “relu”, “selu”, “gelu_new”, and “quick_gelu” are supported.

  • layer_norm_eps (float, optional, defaults to 1e-05) – The epsilon used by the layer normalization layers.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • initializer_factor (float, optional, defaults to 1.0) – A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing).

  • pad_token_id (int, optional, defaults to 1) – Padding token id.

  • bos_token_id (int, optional, defaults to 49406) – Beginning of stream token id.

  • eos_token_id (int, optional, defaults to 49407) – End of stream token id.
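The default activation, quick_gelu, is the sigmoid approximation of GELU used by the original CLIP models: x · sigmoid(1.702 · x). The small sketch below compares it with the exact erf-based GELU. It only illustrates the formula and is not EasyDeL's implementation.

```python
import math


def quick_gelu(x: float) -> float:
    # Sigmoid approximation of GELU: x * sigmoid(1.702 * x)
    return x / (1.0 + math.exp(-1.702 * x))


def gelu(x: float) -> float:
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(f"{x:+.1f}  quick_gelu={quick_gelu(x):+.4f}  gelu={gelu(x):+.4f}")
```

The two curves differ by only about 0.02 at most on [-3, 3], which is why checkpoints trained with one are usually paired with the same activation at inference time rather than swapped.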

Example:

```python
>>> from transformers import CLIPTextConfig, CLIPTextModel

>>> # Initializing a CLIPTextConfig with openai/clip-vit-base-patch32 style configuration
>>> configuration = CLIPTextConfig()
>>> # Initializing a CLIPTextModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
>>> model = CLIPTextModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
base_config_key: str = 'text_config'#
get_partition_rules(*arg, **kwargs)#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'clip_text_model'#
class easydel.__init__.CLIPTextModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.CLIPTextModelWithProjection(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.CLIPVisionConfig(hidden_size=768, intermediate_size=3072, projection_dim=512, num_hidden_layers=12, num_attention_heads=12, num_channels=3, image_size=224, patch_size=32, hidden_act='quick_gelu', layer_norm_eps=1e-05, attention_dropout=0.0, initializer_range=0.02, initializer_factor=1.0, **kwargs)[source]#

Bases: EasyDeLBaseConfig

This is the configuration class to store the configuration of a [CLIPVisionModel]. It is used to instantiate a CLIP vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the vision encoder of the CLIP [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • hidden_size (int, optional, defaults to 768) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 3072) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • projection_dim (int, optional, defaults to 512) – Dimensionality of text and vision projection layers.

  • num_hidden_layers (int, optional, defaults to 12) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 12) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_channels (int, optional, defaults to 3) – The number of input channels.

  • image_size (int, optional, defaults to 224) – The size (resolution) of each image.

  • patch_size (int, optional, defaults to 32) – The size (resolution) of each patch.

  • hidden_act (str or function, optional, defaults to “quick_gelu”) – The non-linear activation function (function or string) in the encoder and pooler. If a string, “gelu”, “relu”, “selu”, “gelu_new”, and “quick_gelu” are supported.

  • layer_norm_eps (float, optional, defaults to 1e-05) – The epsilon used by the layer normalization layers.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • initializer_factor (float, optional, defaults to 1.0) – A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing).
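Together, image_size and patch_size fix the length of the vision encoder's token sequence. Assuming the standard ViT layout used by CLIP (non-overlapping square patches plus one class token), the defaults give (224 / 32)² = 49 patches and 50 positions. The helper name below is illustrative.

```python
def vision_sequence_length(image_size: int = 224, patch_size: int = 32) -> int:
    """Number of positions seen by a ViT-style CLIP vision encoder (patches + class token)."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    num_patches = (image_size // patch_size) ** 2
    return num_patches + 1  # +1 for the class ([CLS]) embedding


print(vision_sequence_length())         # 50  (7 x 7 patches + 1)
print(vision_sequence_length(224, 14))  # 257 (16 x 16 patches + 1)
```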

Example:

```python
>>> from transformers import CLIPVisionConfig, CLIPVisionModel

>>> # Initializing a CLIPVisionConfig with openai/clip-vit-base-patch32 style configuration
>>> configuration = CLIPVisionConfig()
>>> # Initializing a CLIPVisionModel (with random weights) from the openai/clip-vit-base-patch32 style configuration
>>> model = CLIPVisionModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
base_config_key: str = 'vision_config'#
get_partition_rules(*arg, **kwargs)#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'clip_vision_model'#
class easydel.__init__.CLIPVisionModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.Cohere2Config(vocab_size=256000, hidden_size=8192, intermediate_size=22528, logit_scale=0.0625, num_hidden_layers=40, num_attention_heads=64, num_key_value_heads=None, hidden_act='silu', max_position_embeddings=8192, initializer_range=0.02, layer_norm_eps=1e-05, use_cache=True, pad_token_id=0, bos_token_id=5, eos_token_id=255001, tie_word_embeddings=True, rope_theta=10000.0, rope_scaling=None, attention_bias=False, attention_dropout=0.0, sliding_window=4096, sliding_window_pattern=4, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 256000) – Vocabulary size of the Cohere model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 8192) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 22528) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • logit_scale (float, optional, defaults to 0.0625) – The scaling factor applied to the output logits.

  • num_hidden_layers (int, optional, defaults to 40) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 64) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer encoder. Will default to num_attention_heads if not set.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 8192) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • layer_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional, defaults to 0) – The index of the padding token in the vocabulary.

  • bos_token_id (int, optional, defaults to 5) – The index of the beginning of sequence token in the vocabulary.

  • eos_token_id (int, optional, defaults to 255001) – The index of the end of sequence token in the vocabulary.

  • tie_word_embeddings (bool, optional, defaults to True) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • attention_bias (bool, optional, defaults to False) – Whether to use attention bias.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • use_qk_norm (bool, optional, defaults to False) – Whether to use query and key normalization.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing strategy.

  • bits (int, optional) – The number of bits to quantize the model to.
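num_key_value_heads controls grouped-query attention. When it is smaller than num_attention_heads, several query heads share one key/value head. When it is left unset, every query head gets its own key/value head, which is standard multi-head attention. The following sketch shows the usual head-to-group mapping. It assumes that consecutive query heads share a key/value head, as in common GQA implementations, and the helper name is hypothetical.

```python
from typing import List, Optional


def kv_head_for_query_heads(num_attention_heads: int = 64, num_key_value_heads: Optional[int] = None) -> List[int]:
    """Return, for each query head, the index of the key/value head it attends with."""
    if num_key_value_heads is None:  # documented default: fall back to num_attention_heads
        num_key_value_heads = num_attention_heads
    if num_attention_heads % num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be a multiple of num_key_value_heads")
    group_size = num_attention_heads // num_key_value_heads
    return [h // group_size for h in range(num_attention_heads)]


print(kv_head_for_query_heads(8, 2))  # [0, 0, 0, 0, 1, 1, 1, 1]
print(kv_head_for_query_heads(4))     # [0, 1, 2, 3] (plain multi-head attention)
```

Fewer key/value heads shrink the KV cache proportionally (8 query heads over 2 KV heads store a quarter of the keys and values), at some cost in modeling capacity.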

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()[source]#
property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
model_type: str = 'cohere'#
static rng_keys()[source]#
class easydel.__init__.Cohere2ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.Cohere2ForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.Cohere2Model(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.CohereConfig(vocab_size=256000, hidden_size=8192, intermediate_size=22528, logit_scale=0.0625, num_hidden_layers=40, num_attention_heads=64, num_key_value_heads=None, hidden_act='silu', max_position_embeddings=8192, initializer_range=0.02, layer_norm_eps=1e-05, use_cache=True, pad_token_id=0, bos_token_id=5, eos_token_id=255001, tie_word_embeddings=True, rope_theta=10000.0, attention_bias=False, attention_dropout=0.0, use_qk_norm: bool = False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 256000) – Vocabulary size of the Cohere model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 8192) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 22528) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • logit_scale (float, optional, defaults to 0.0625) – The scaling factor applied to the output logits.

  • num_hidden_layers (int, optional, defaults to 40) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 64) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer encoder. Will default to num_attention_heads if not set.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 8192) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • layer_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional, defaults to 0) – The index of the padding token in the vocabulary.

  • bos_token_id (int, optional, defaults to 5) – The index of the beginning of sequence token in the vocabulary.

  • eos_token_id (int, optional, defaults to 255001) – The index of the end of sequence token in the vocabulary.

  • tie_word_embeddings (bool, optional, defaults to True) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • attention_bias (bool, optional, defaults to False) – Whether to use attention bias.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • use_qk_norm (bool, optional, defaults to False) – Whether to use query and key normalization.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing strategy.

  • bits (int, optional) – The number of bits to quantize the model to.
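rope_theta is the base of the rotary position embedding frequencies. The following sketch assumes the standard RoPE formulation, in which each channel pair i of a head with dimension d gets an inverse frequency of theta^(-2i/d). It also assumes d = hidden_size / num_attention_heads, which is 8192 / 64 = 128 with the defaults. The sketch computes those frequencies and the rotation applied to one channel pair at position p. How channels are paired (interleaved or split halves) is left out, and the helper names are illustrative.

```python
import math
from typing import List, Tuple


def rope_inv_freq(head_dim: int, rope_theta: float = 10000.0) -> List[float]:
    """Inverse frequencies theta ** (-2i / head_dim) for i in [0, head_dim / 2)."""
    return [rope_theta ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]


def rotate_pair(x0: float, x1: float, position: int, inv_freq: float) -> Tuple[float, float]:
    """Rotate one 2-D channel pair by the angle position * inv_freq."""
    angle = position * inv_freq
    c, s = math.cos(angle), math.sin(angle)
    return x0 * c - x1 * s, x0 * s + x1 * c


head_dim = 8192 // 64  # hidden_size // num_attention_heads with the defaults -> 128
freqs = rope_inv_freq(head_dim)
print(len(freqs), freqs[0], freqs[-1])  # 64 pairs; the fastest frequency is 1.0
print(rotate_pair(1.0, 0.0, position=3, inv_freq=freqs[0]))
```

A larger rope_theta makes the slowest frequencies rotate even more slowly, which is the usual lever for supporting longer contexts.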

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

The attach_custom_arguments function attaches the following EasyDeL-specific arguments to the configuration object:

Parameters
  • gradient_checkpointing (EasyDeLGradientCheckPointers) – The gradient checkpointing strategy, which controls how many activations JAX keeps in memory versus recomputes during the backward pass.

  • bits (tp.Optional[int]) – The number of bits used for quantization.

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()[source]#
property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
model_type: str = 'cohere'#
static rng_keys()[source]#
class easydel.__init__.CohereForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.CohereForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.CohereModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.ConfigType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: str, Enum

MODULE_CONFIG = 'module-config'#
class easydel.__init__.DPOConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: ~typing.Optional[str] = None, clip_grad: ~typing.Optional[float] = None, custom_scheduler: ~typing.Optional[~typing.Callable[[int], ~typing.Any]] = None, dataloader_num_workers: ~typing.Optional[int] = 0, dataloader_pin_memory: ~typing.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: ~typing.Optional[int] = None, evaluation_steps: ~typing.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: ~typing.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: ~typing.Optional[~typing.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: ~typing.Optional[dict] = None, learning_rate: float = 5e-05, learning_rate_end: ~typing.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: ~typing.Optional[~easydel.infra.loss_utils.LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: ~typing.Optional[int] = None, max_sequence_length: ~typing.Optional[int] = 4096, max_training_steps: ~typing.Optional[int] = None, model_name: str = 'BaseTrainer', model_parameters: ~typing.Optional[dict] = None, metrics_to_show_in_rich_pbar: ~typing.Optional[~typing.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: ~typing.Literal['adafactor', 'lion', 'adamw', 'rmsprop'] = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: ~typing.Any = None, process_zero_is_admin: bool = True, progress_bar_type: ~typing.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', 
save_optimizer_state: bool = True, save_steps: ~typing.Optional[int] = None, save_total_limit: ~typing.Optional[int] = None, scheduler: ~typing.Literal['linear', 'cosine', 'none'] = EasyDeLSchedulers.NONE, sparsify_module: bool = False, sparse_module_type: ~typing.Literal['bcoo', 'bcsr', 'coo', 'csr'] = 'bcoo', state_apply_fn_kwarguments_to_model: ~typing.Optional[dict] = None, step_partition_spec: ~jax._src.partition_spec.PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: ~typing.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: ~typing.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: ~typing.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: ~typing.Optional[~numpy.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: ~typing.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0)[source]#

Bases: TrainingArguments

Configuration class for Direct Preference Optimization (DPO) training.

Inherits from TrainingArguments and adds parameters specific to DPO training as described in https://arxiv.org/abs/2305.18290. This configuration controls various aspects of the DPO training process including loss computation, model architecture, and dataset processing.

beta#

Temperature parameter (β) controlling how far the policy may deviate from the reference model. Higher values keep the policy closer to the reference model. Default: 0.1

Type

float

label_smoothing#

Smoothing factor for labels in loss calculation. Helps prevent overconfidence. 0.0 means no smoothing. Default: 0.0

Type

float
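Both beta and label_smoothing enter the default sigmoid loss. Following the DPO paper, the loss for one preference pair is -log sigmoid(β · h). Here h is the chosen log-ratio minus the rejected log-ratio, where each log-ratio compares the policy's log-probability with the reference model's. With label smoothing ε, the conservative variant also mixes in -log sigmoid(-β · h). The following pure-Python sketch makes that assumption and mirrors the common TRL-style formulation, though EasyDeL's internals may differ. The function names are illustrative.

```python
import math


def log_sigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x))
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))


def dpo_sigmoid_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.1,
    label_smoothing: float = 0.0,
) -> float:
    """DPO loss for one (chosen, rejected) pair of summed sequence log-probabilities."""
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    h = chosen_logratio - rejected_logratio
    return -(1.0 - label_smoothing) * log_sigmoid(beta * h) - label_smoothing * log_sigmoid(-beta * h)


# The policy prefers the chosen answer more than the reference does -> loss below log(2).
print(dpo_sigmoid_loss(-10.0, -14.0, -12.0, -12.0))
# Equal log-ratios -> the loss is exactly log(2) for any beta.
print(dpo_sigmoid_loss(-12.0, -12.0, -12.0, -12.0), math.log(2))
```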

loss_type#

Type of contrastive loss function to use. Valid options: ‘sigmoid’, ‘hinge’, ‘ipo’, ‘exo_pair’, ‘nca_pair’, ‘robust’, ‘bco_pair’, ‘sppo_hard’, ‘aot’, ‘aot_pair’, ‘apo_zero’, ‘apo_down’. Default: ‘sigmoid’

Type

LOSS_FN_VARIENTS
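The loss_type setting changes how the same log-ratio margin h (chosen log-ratio minus rejected log-ratio) is turned into a loss. The following sketch covers only three of the listed options. It assumes the formulations popularized by TRL: sigmoid from the DPO paper, a SLiC-style hinge loss, and the squared IPO objective. EasyDeL's exact implementation may differ.

```python
import math


def preference_loss(h: float, loss_type: str = "sigmoid", beta: float = 0.1) -> float:
    """Map the margin h = (chosen log-ratio) - (rejected log-ratio) to a per-pair loss."""
    if loss_type == "sigmoid":
        # -log(sigmoid(beta * h)), computed as softplus(-beta * h) for stability
        z = -beta * h
        return max(z, 0.0) + math.log1p(math.exp(-abs(z)))
    if loss_type == "hinge":
        return max(0.0, 1.0 - beta * h)
    if loss_type == "ipo":
        # IPO regresses the margin towards 1 / (2 * beta) instead of pushing it ever higher
        return (h - 1.0 / (2.0 * beta)) ** 2
    raise ValueError(f"loss_type {loss_type!r} is not covered by this sketch")


for lt in ("sigmoid", "hinge", "ipo"):
    print(lt, [round(preference_loss(h, lt), 4) for h in (0.0, 5.0, 10.0)])
```

The contrast is visible at h = 10: sigmoid keeps shrinking, hinge is already zero, and IPO penalizes the margin for overshooting its target of 5.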

use_weighting#

Whether to apply example weighting in loss calculation. Default: False

Type

bool

label_pad_token_id#

Token ID used for padding labels. Default: -100

Type

int
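Positions whose label equals label_pad_token_id are excluded when a sequence's log-probability is computed, so prompt and padding tokens do not contribute to the preference loss. The following is a minimal sketch of that masking. It assumes per-token log-probabilities have already been gathered for each label, and the helper name is hypothetical.

```python
from typing import Sequence

LABEL_PAD_TOKEN_ID = -100  # documented default of label_pad_token_id


def sequence_logp(
    token_logps: Sequence[float],
    labels: Sequence[int],
    label_pad_token_id: int = LABEL_PAD_TOKEN_ID,
) -> float:
    """Sum token log-probs over the positions whose label is not the pad id."""
    return sum(lp for lp, lab in zip(token_logps, labels) if lab != label_pad_token_id)


# Prompt tokens (first two) and trailing padding are masked with -100.
labels = [-100, -100, 523, 88, 2, -100]
token_logps = [-3.0, -2.5, -0.5, -1.0, -0.25, -9.0]
print(sequence_logp(token_logps, labels))  # -1.75
```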

padding_value#

Value used for padding sequences. If None, the model’s default padding token is used. Default: None

Type

int | None

max_length#

Maximum total sequence length (prompt + completion). Default: 512

Type

int | None

max_prompt_length#

Maximum length for prompt sequences. Default: 256

Type

int | None

max_completion_length#

Maximum length for completion sequences. Auto-calculated as max_length - max_prompt_length if None. Default: None

Type

int | None
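The three length limits interact through the documented fallback for max_completion_length. The following sketch implements that rule with a hypothetical helper name. Two parts of it are assumptions of this sketch: returning None when there is nothing to derive from, and rejecting a prompt limit that leaves no room for a completion.

```python
from typing import Optional


def resolve_completion_length(
    max_length: Optional[int] = 512,
    max_prompt_length: Optional[int] = 256,
    max_completion_length: Optional[int] = None,
) -> Optional[int]:
    """Documented rule: if unset, max_completion_length = max_length - max_prompt_length."""
    if max_completion_length is not None:
        return max_completion_length  # an explicit value always wins
    if max_length is None or max_prompt_length is None:
        return None  # nothing to derive from; this sketch treats it as "no truncation"
    if max_prompt_length >= max_length:
        raise ValueError("max_prompt_length must be smaller than max_length")
    return max_length - max_prompt_length


print(resolve_completion_length())                # 256 with the defaults
print(resolve_completion_length(1024, 256))       # 768
print(resolve_completion_length(1024, 256, 512))  # 512
```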

is_encoder_decoder#

Explicitly sets whether the model is an encoder-decoder model. Auto-detected if None. Default: None

Type

bool | None

disable_dropout#

Whether to disable dropout during training for deterministic behavior. Default: True

Type

bool

precompute_ref_log_probs#

Whether to precompute reference model log probabilities before training. Default: False

Type

bool

dataset_num_proc#

Number of processes for dataset preprocessing. Default: None (sequential processing)

Type

int | None

reference_free#

Whether to use the reference-free variant of DPO. Default: False

Type

bool

force_use_ref_model#

Force the use of the reference model even when reference_free=True. Default: False

Type

bool

sync_ref_model#

Whether to periodically sync the reference model with the training model. Default: False

Type

bool

learning_rate#

Optimizer learning rate. Default: 1e-6

Type

float

ref_model_mixup_alpha#

Alpha parameter for mixup between policy and reference models. Default: 0.9

Type

float

ref_model_sync_steps#

Number of steps between reference model syncs. Default: 64

Type

int
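When sync_ref_model is enabled, the reference model is refreshed every ref_model_sync_steps steps. The following sketch assumes the TR-DPO-style soft update that TRL uses for these same parameter names: each reference weight becomes α · θ_policy + (1 - α) · θ_ref, with α = ref_model_mixup_alpha. It works over a flat dict of scalar weights, and all names in it are hypothetical.

```python
from typing import Dict

Params = Dict[str, float]


def mixup_reference(policy: Params, reference: Params, alpha: float = 0.9) -> Params:
    """Soft update: ref <- alpha * policy + (1 - alpha) * ref, applied per parameter."""
    return {k: alpha * policy[k] + (1.0 - alpha) * reference[k] for k in reference}


def maybe_sync(step: int, policy: Params, reference: Params, sync_ref_model: bool = True,
               ref_model_sync_steps: int = 64, alpha: float = 0.9) -> Params:
    """Return the (possibly updated) reference parameters after a training step."""
    if sync_ref_model and step > 0 and step % ref_model_sync_steps == 0:
        return mixup_reference(policy, reference, alpha)
    return reference


ref = {"w": 0.0}
pol = {"w": 1.0}
for step in range(1, 129):
    ref = maybe_sync(step, pol, ref)
# Synced at steps 64 and 128: 0.9, then 0.9 * 1.0 + 0.1 * 0.9, so about 0.99
print(ref)
```

With alpha = 1.0 the update degenerates to a hard copy of the policy, while smaller values move the reference more gradually.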

rpo_alpha#

Alpha parameter for Relative Preference Optimization. None disables RPO. Default: None

Type

float | None

tools#

Additional tools for the training process.

Type

list[dict | Callable] | None

Example

>>> from easydel import DPOConfig
>>> config = DPOConfig(
...   beta=0.2, loss_type="ipo", max_length=1024, learning_rate=5e-6
... )
beta: float = Field(name=None,type=None,default=0.1,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Temperature parameter (β) controlling deviation from reference model. Higher values make training focus more on preference matching.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
dataset_num_proc: Optional[int] = Field(name=None,type=None,default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of processes for dataset preprocessing. Default: None (sequential processing)'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
disable_dropout: bool = Field(name=None,type=None,default=True,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether to disable dropout during training for deterministic behavior.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
force_use_ref_model: bool = Field(name=None,type=None,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Force use reference model even when reference_free=True.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
is_encoder_decoder: Optional[bool] = Field(name=None,type=None,default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Explicitly set if model is encoder-decoder. Auto-detected if None.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
label_pad_token_id: int = Field(name=None,type=None,default=-100,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Token ID used for padding labels.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
label_smoothing: float = Field(name=None,type=None,default=0.0,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Smoothing factor for labels in loss calculation. Helps prevent overconfidence. 0.0 means no smoothing.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
learning_rate: float = Field(name=None,type=None,default=1e-06,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Optimizer learning rate.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
loss_type: Literal['sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down'] = Field(name=None,type=None,default='sigmoid',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': "Type of contrastive loss function to use. Valid options: 'sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down'."}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
max_completion_length: Optional[int] = Field(name=None,type=None,default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Maximum length for completion sequences. Auto-calculated as max_length - max_prompt_length if None.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
max_length: Optional[int] = Field(name=None,type=None,default=512,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Maximum total sequence length (prompt + completion).'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
max_prompt_length: Optional[int] = Field(name=None,type=None,default=256,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Maximum length for prompt sequences.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
model_name: str = Field(name=None,type=None,default='DPOTrainer',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The name of the model.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
padding_value: Optional[int] = Field(name=None,type=None,default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': "Value used for padding sequences. If None, uses model's default padding token."}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
precompute_ref_log_probs: bool = Field(name=None,type=None,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether to precompute reference model log probabilities before training.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
ref_model_mixup_alpha: float = Field(name=None,type=None,default=0.9,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Alpha parameter for mixup between policy and reference models.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
ref_model_sync_steps: int = Field(name=None,type=None,default=64,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of steps between reference model syncs.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
reference_free: bool = Field(name=None,type=None,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether to use reference-free variant of DPO.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
replace(**kwargs)#
rpo_alpha: Optional[float] = Field(name=None,type=None,default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Alpha parameter for Relative Preference Optimization. None disables RPO.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
sync_ref_model: bool = Field(name=None,type=None,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether to periodically sync reference model with training model.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
tools: Optional[List[Union[dict, Callable]]] = Field(name=None,type=None,default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Additional tools for training process.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
use_weighting: bool = Field(name=None,type=None,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether to apply example weighting in loss calculation.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
class easydel.__init__.DPOTrainer(arguments: DPOConfig, model: Union[EasyDeLBaseModule, EasyDeLState], reference_model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, processing_class: Optional[Any] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Any] = None, data_collator: Optional[Callable] = None)[source]#

Bases: Trainer

Trainer for Direct Preference Optimization (DPO).

This trainer handles the training, evaluation, and checkpointing of language models using the DPO algorithm. It supports sharding, gradient accumulation, mixed precision training, LoRA, and precomputed reference model log probabilities.

arguments: DPOConfig#
compute_reference_log_probs(state: EasyDeLState, padded_batch: Dict) tuple[Any, Any][source]#

Computes log probabilities of the reference model for a single padded batch of a DPO-specific dataset.

Parameters
  • state (EasyDeLState) – The EasyDeLState object of the model (used if no reference model is provided).

  • padded_batch (tp.Dict) – The padded batch of data.

Returns

A tuple containing the log probabilities for the chosen and rejected responses.

Return type

tuple[tp.Any, tp.Any]

configure_dataloaders()[source]#

Returns the training dataloader, potentially with precomputed reference log probabilities.

If precompute_ref_log_probs is enabled, this method computes the reference model’s log probabilities for the chosen and rejected responses in the training dataset and adds them as columns to the dataset.

Returns

The training dataloader.

Return type

tensorflow.data.Dataset

configure_functions() TrainerConfigureFunctionOutput[source]#

Configures and JIT-compiles the training and evaluation step functions.

This method sets up the necessary functions for training and evaluation, including:
  • Initialization of the model state.

  • Sharding of the model parameters and optimizer state.

  • JIT-compilation of the training and evaluation step functions.

Returns

An object containing the configured functions and other relevant information.

Return type

TrainerConfigureFunctionOutput

create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') → Callable

Creates a data collection function for batching.

For DPO training, this method simply returns the pre-configured data_collator.

Parameters
  • max_sequence_length (int) – The maximum sequence length (not used in this implementation).

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode (not used in this implementation). Defaults to “keep_end”.

Returns

The data collator function.

Return type

tp.Callable

on_step_end(state: EasyDeLState, metrics: Any, step: int) → Tuple[EasyDeLState, Any]

Hook called at the end of each training step; returns the state and metrics as a tuple.

static process_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens)
static tokenize_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens)

Tokenize a row of the dataset.

Parameters
  • features (dict[str, str]) – Row of the dataset, should contain the keys “prompt”, “chosen”, and “rejected”.

  • processing_class (PreTrainedTokenizerBase) – Processing class used to process the data.

  • max_prompt_length (int or None) – Maximum length of the prompt sequence. If None, the prompt sequence is not truncated.

  • max_completion_length (int or None) – Maximum length of the completion sequences. If None, the completion sequences are not truncated.

  • add_special_tokens (bool) – Whether to add special tokens to the sequences. Typically used for encoder-decoder models. If True, the prompt sequence will have a bos token prepended and an eos token appended. In any case, the completion sequences will have an eos token appended.

Returns

Tokenized sequences with the keys “prompt_input_ids”, “chosen_input_ids”, and “rejected_input_ids”.

Return type

dict[str, list[int]]
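The behaviour described above can be sketched with a toy tokenizer. The sketch assumes the TRL-style convention of keeping the *last* max_prompt_length prompt tokens and the *first* max_completion_length completion tokens. That truncation side, the whitespace "tokenizer", and the BOS/EOS ids are assumptions, not confirmed EasyDeL behaviour.

```python
from typing import Dict, List, Optional

BOS_ID, EOS_ID = 1, 2  # hypothetical special-token ids


def toy_encode(text: str) -> List[int]:
    # Toy stand-in tokenizer: one id per whitespace-separated word, derived from its length.
    return [10 + len(word) for word in text.split()]


def tokenize_row_sketch(
    features: Dict[str, str],
    max_prompt_length: Optional[int],
    max_completion_length: Optional[int],
    add_special_tokens: bool,
) -> Dict[str, List[int]]:
    prompt = toy_encode(features["prompt"])
    chosen = toy_encode(features["chosen"])
    rejected = toy_encode(features["rejected"])
    if add_special_tokens:
        # The prompt gets BOS prepended and EOS appended.
        prompt = [BOS_ID] + prompt + [EOS_ID]
    # Completions always get EOS appended.
    chosen = chosen + [EOS_ID]
    rejected = rejected + [EOS_ID]
    if max_prompt_length is not None:
        prompt = prompt[-max_prompt_length:]  # keep the end of the prompt
    if max_completion_length is not None:
        chosen = chosen[:max_completion_length]  # keep the start of each completion
        rejected = rejected[:max_completion_length]
    return {
        "prompt_input_ids": prompt,
        "chosen_input_ids": chosen,
        "rejected_input_ids": rejected,
    }
```

Keeping the end of the prompt preserves the text closest to the response, which is usually the most relevant context.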

class easydel.__init__.DbrxAttentionConfig(attn_pdrop: float = 0, clip_qkv: Optional[float] = 8, kv_n_heads: int = 1, rope_theta: float = 10000.0, **kwargs: Any)

Bases: EasyDeLBaseConfig

This is the configuration class to store the attention-related configuration of a [DbrxModel].

Parameters
  • attn_pdrop (float, optional, defaults to 0.0) – The dropout probability applied to the attention output.

  • clip_qkv (float, optional, defaults to 8.0) – The clip value applied to the query, key, and value tensors.

  • kv_n_heads (int, optional, defaults to 1) – The number of attention heads for the key and value tensors.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value for the rotary position embedding.
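clip_qkv bounds the projected query, key, and value activations elementwise. Below is a minimal sketch of that clamp. It assumes the clamp is symmetric to [-clip_qkv, clip_qkv] and follows the fused QKV projection, as in the reference DBRX implementation. It also assumes clipping is skipped when clip_qkv is None, since the parameter is Optional.

```python
from typing import List, Optional


def clip_qkv_values(qkv: List[float], clip_qkv: Optional[float] = 8.0) -> List[float]:
    # No clipping when clip_qkv is None.
    if clip_qkv is None:
        return list(qkv)
    # Clamp every activation into [-clip_qkv, clip_qkv].
    return [min(max(x, -clip_qkv), clip_qkv) for x in qkv]
```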

classmethod from_pretrained(pretrained_model_name_or_path: str, **kwargs: Any) → PretrainedConfig

Instantiate a [PretrainedConfig] (or a derived class) from a pretrained model configuration.

Parameters
  • pretrained_model_name_or_path (str or os.PathLike) –

    This can be either:

    • a string, the model id of a pretrained model configuration hosted inside a model repo on huggingface.co.

    • a path to a directory containing a configuration file saved using the [~PretrainedConfig.save_pretrained] method, e.g., ./my_model_directory/.

    • a path or url to a saved configuration JSON file, e.g., ./my_model_directory/configuration.json.

  • cache_dir (str or os.PathLike, optional) – Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.

  • force_download (bool, optional, defaults to False) – Whether or not to force to (re-)download the configuration files and override the cached versions if they exist.

  • resume_download – Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.

  • proxies (Dict[str, str], optional) – A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.

  • token (str or bool, optional) – The token to use as HTTP bearer authorization for remote files. If True, or not specified, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).

  • revision (str, optional, defaults to “main”) –

    The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.

    Tip: to test a pull request you made on the Hub, you can pass revision="refs/pr/<pr_number>".

  • return_unused_kwargs (bool, optional, defaults to False) –

    If False, then this function returns just the final configuration object.

    If True, then this function returns a tuple (config, unused_kwargs) where unused_kwargs is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the part of kwargs which has not been used to update config and is otherwise ignored.

  • subfolder (str, optional, defaults to “”) – In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here.

  • kwargs (Dict[str, tp.Any], optional) – The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not configuration attributes is controlled by the return_unused_kwargs keyword parameter.

Returns

The configuration object instantiated from this pretrained model.

Return type

[PretrainedConfig]

Examples:

>>> from transformers import BertConfig
>>> # We can't directly instantiate the base class *PretrainedConfig*, so the examples use a
>>> # derived class: BertConfig
>>> config = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased"
... )  # Download configuration from huggingface.co and cache.
>>> config = BertConfig.from_pretrained(
...     "./test/saved_model/"
... )  # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
>>> config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
>>> config = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased", output_attentions=True, foo=False
... )
>>> assert config.output_attentions == True
>>> config, unused_kwargs = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased",
...     output_attentions=True,
...     foo=False,
...     return_unused_kwargs=True,
... )
>>> assert config.output_attentions == True
>>> assert unused_kwargs == {"foo": False}


class easydel.__init__.DbrxConfig(d_model: int = 2048, n_heads: int = 16, n_layers: int = 24, max_seq_len: int = 2048, vocab_size: int = 32000, resid_pdrop: float = 0.0, emb_pdrop: float = 0.0, attn_config: Optional[DbrxAttentionConfig] = None, ffn_config: Optional[DbrxFFNConfig] = None, use_cache: bool = True, initializer_range: float = 0.02, output_router_logits: bool = False, router_aux_loss_coef: float = 0.05, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs: Any)

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • d_model (int, optional, defaults to 2048) – Dimensionality of the embeddings and hidden states.

  • n_heads (int, optional, defaults to 16) – Number of attention heads for each attention layer in the Transformer decoder.

  • n_layers (int, optional, defaults to 24) – Number of hidden layers in the Transformer decoder.

  • max_seq_len (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • vocab_size (int, optional, defaults to 32000) – Vocabulary size of the DBRX model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • resid_pdrop (float, optional, defaults to 0.0) – The dropout probability applied to the attention output before it is combined with the residual.

  • emb_pdrop (float, optional, defaults to 0.0) – The dropout probability for the embedding layer.

  • attn_config ([DbrxAttentionConfig], optional) – The configuration of the attention layer.

  • ffn_config ([DbrxFFNConfig], optional) – The configuration of the feed forward layer.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • output_router_logits (bool, optional, defaults to False) – Whether or not to output the router logits.

  • router_aux_loss_coef (float, optional, defaults to 0.05) – The coefficient of the router auxiliary loss.

attribute_map: Dict[str, str] = {'hidden_size': 'd_model', 'max_position_embeddings': 'max_seq_len', 'num_attention_heads': 'n_heads', 'num_hidden_layers': 'n_layers'}
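attribute_map lets generic code use the standard Hugging Face names (hidden_size, num_hidden_layers, and so on), while DBRX stores those values under its own names (d_model, n_layers, and so on). The class below is a simplified, hypothetical re-creation of that aliasing using `__getattr__`/`__setattr__`. PretrainedConfig's real mechanism differs in detail.

```python
from typing import Any, Dict


class AliasedConfig:
    # Same mapping as DbrxConfig.attribute_map: standard name -> stored name.
    attribute_map: Dict[str, str] = {
        "hidden_size": "d_model",
        "max_position_embeddings": "max_seq_len",
        "num_attention_heads": "n_heads",
        "num_hidden_layers": "n_layers",
    }

    def __init__(self, **kwargs: Any) -> None:
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __setattr__(self, name: str, value: Any) -> None:
        # Writes through an alias land on the underlying attribute.
        super().__setattr__(self.attribute_map.get(name, name), value)

    def __getattr__(self, name: str) -> Any:
        # Only called when normal lookup fails, i.e. for alias names.
        if name in type(self).attribute_map:
            return getattr(self, type(self).attribute_map[name])
        raise AttributeError(name)


cfg = AliasedConfig(d_model=2048, n_layers=24)
print(cfg.hidden_size, cfg.num_hidden_layers)  # 2048 24
```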
get_partition_rules(*args, **kwargs)

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]
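Partition rules are (regex, PartitionSpec) pairs that are matched against parameter paths. Below is a minimal sketch of how such rules are commonly resolved. It assumes first-match-wins with `re.search` and "/"-separated parameter paths. The rules themselves are hypothetical, and plain tuples of mesh-axis names stand in for `jax.sharding.PartitionSpec` so that the snippet runs without JAX.

```python
import re
from typing import Sequence, Tuple

Spec = Tuple[object, ...]  # stand-in for PartitionSpec(...)

# Hypothetical rules, most specific first; the last one is a catch-all.
RULES: Sequence[Tuple[str, Spec]] = (
    (r"wte/embedding", ("tp", "fsdp")),
    (r"(q_proj|k_proj|v_proj)/kernel", ("fsdp", "tp")),
    (r"o_proj/kernel", ("tp", "fsdp")),
    (r".*norm.*", (None,)),
    (r".*", (None,)),  # replicate everything else
)


def match_partition_spec(param_path: str, rules: Sequence[Tuple[str, Spec]] = RULES) -> Spec:
    # Return the spec of the first rule whose pattern occurs in the path.
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")
```

Because the first match wins, the order of the rules matters: specific patterns must come before broad ones such as the catch-all.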

property granted_freq_max_position_embedding: int
property granted_mask_max_position_embedding: int
model_type: str = 'dbrx'
class easydel.__init__.DbrxFFNConfig(ffn_act_fn: Optional[dict] = None, ffn_hidden_size: int = 3584, moe_num_experts: int = 4, moe_top_k: int = 1, moe_jitter_eps: Optional[float] = None, moe_loss_weight: float = 0.01, moe_normalize_expert_weights: Optional[float] = 1, uniform_expert_assignment: bool = False, **kwargs: Any)

Bases: EasyDeLBaseConfig

This is the configuration class to store the feed-forward configuration of a [DbrxModel].

Parameters
  • ffn_act_fn (dict, optional) – The activation function configuration for the feed-forward network.

  • ffn_hidden_size (int, optional, defaults to 3584) – The hidden size of the feed-forward network.

  • moe_num_experts (int, optional, defaults to 4) – The number of experts in the Mixture-of-Experts (MoE) layer.

  • moe_top_k (int, optional, defaults to 1) – The number of top experts to use in the MoE layer.

  • moe_jitter_eps (float, optional) – The jitter epsilon value for the MoE layer.

  • moe_loss_weight (float, optional, defaults to 0.01) – The loss weight for the MoE auxiliary loss.

  • moe_normalize_expert_weights (float, optional, defaults to 1.0) – The normalization factor for the expert weights in the MoE layer.

  • uniform_expert_assignment (bool, optional, defaults to False) – Whether to use uniform expert assignment in the MoE layer.
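The routing parameters above can be illustrated for a single token: softmax over the router logits, keep the moe_top_k largest weights, then, when moe_normalize_expert_weights is set, divide those weights by their p-norm with p = moe_normalize_expert_weights. This follows the reference DBRX router as we understand it and should be treated as a sketch. moe_jitter_eps (training-time input noise) is omitted.

```python
import math
from typing import List, Optional, Tuple


def route_token(
    router_logits: List[float],
    moe_top_k: int = 1,
    moe_normalize_expert_weights: Optional[float] = 1.0,
) -> List[Tuple[int, float]]:
    # Softmax over all experts (max-subtracted for stability).
    m = max(router_logits)
    exps = [math.exp(x - m) for x in router_logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the top-k experts by probability (ties broken by lower index).
    top = sorted(range(len(probs)), key=lambda i: (-probs[i], i))[:moe_top_k]
    weights = [probs[i] for i in top]
    if moe_normalize_expert_weights is not None:
        # Divide by the p-norm of the selected weights.
        p = moe_normalize_expert_weights
        norm = sum(abs(w) ** p for w in weights) ** (1.0 / p)
        weights = [w / norm for w in weights]
    return list(zip(top, weights))
```

With the default p = 1, the selected weights are rescaled to sum to 1.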

classmethod from_pretrained(pretrained_model_name_or_path: str, **kwargs: Any) → EasyDeLBaseConfig

Instantiate a [PretrainedConfig] (or a derived class) from a pretrained model configuration.

Parameters
  • pretrained_model_name_or_path (str or os.PathLike) –

    This can be either:

    • a string, the model id of a pretrained model configuration hosted inside a model repo on huggingface.co.

    • a path to a directory containing a configuration file saved using the [~PretrainedConfig.save_pretrained] method, e.g., ./my_model_directory/.

    • a path or url to a saved configuration JSON file, e.g., ./my_model_directory/configuration.json.

  • cache_dir (str or os.PathLike, optional) – Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.

  • force_download (bool, optional, defaults to False) – Whether or not to force to (re-)download the configuration files and override the cached versions if they exist.

  • resume_download – Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.

  • proxies (Dict[str, str], optional) – A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.

  • token (str or bool, optional) – The token to use as HTTP bearer authorization for remote files. If True, or not specified, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).

  • revision (str, optional, defaults to “main”) –

    The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.

    Tip: to test a pull request you made on the Hub, you can pass revision="refs/pr/<pr_number>".

  • return_unused_kwargs (bool, optional, defaults to False) –

    If False, then this function returns just the final configuration object.

    If True, then this function returns a tuple (config, unused_kwargs) where unused_kwargs is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the part of kwargs which has not been used to update config and is otherwise ignored.

  • subfolder (str, optional, defaults to “”) – In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here.

  • kwargs (Dict[str, tp.Any], optional) – The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not configuration attributes is controlled by the return_unused_kwargs keyword parameter.

Returns

The configuration object instantiated from this pretrained model.

Return type

[PretrainedConfig]

Examples:

>>> from transformers import BertConfig
>>> # We can't directly instantiate the base class *PretrainedConfig*, so the examples use a
>>> # derived class: BertConfig
>>> config = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased"
... )  # Download configuration from huggingface.co and cache.
>>> config = BertConfig.from_pretrained(
...     "./test/saved_model/"
... )  # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
>>> config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
>>> config = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased", output_attentions=True, foo=False
... )
>>> assert config.output_attentions == True
>>> config, unused_kwargs = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased",
...     output_attentions=True,
...     foo=False,
...     return_unused_kwargs=True,
... )
>>> assert config.output_attentions == True
>>> assert unused_kwargs == {"foo": False}


class easydel.__init__.DbrxForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.__init__.DbrxForSequenceClassification(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.__init__.DbrxModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

property frequencies

Returns frequency values from the config.

class easydel.__init__.DeepseekV2Config(vocab_size=102400, hidden_size=4096, intermediate_size=11008, moe_intermediate_size=1407, num_hidden_layers=30, num_attention_heads=32, num_key_value_heads=32, n_shared_experts=None, n_routed_experts=None, ep_size=1, routed_scaling_factor=1.0, kv_lora_rank=512, q_lora_rank=1536, qk_rope_head_dim=64, v_head_dim=128, qk_nope_head_dim=128, topk_method='gready', n_group=None, topk_group=None, num_experts_per_tok=None, moe_layer_freq=1, first_k_dense_replace=0, norm_topk_prob=False, scoring_func='softmax', aux_loss_alpha=0.001, seq_aux=True, hidden_act='silu', max_position_embeddings=2048, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, pad_token_id=None, bos_token_id=100000, eos_token_id=100001, pretraining_tp=1, tie_word_embeddings=False, rope_theta=10000.0, attention_bias=False, attention_dropout=0.0, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, rope_scaling: Dict[str, Union[str, float]] = None, **kwargs)

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 102400) – Vocabulary size of the DeepseekV2 model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4096) – Dimension of the hidden representations.

  • intermediate_size (int, optional, defaults to 11008) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer decoder.

  • moe_intermediate_size (int, optional, defaults to 1407) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the MoE layer.

  • num_hidden_layers (int, optional, defaults to 30) – Number of hidden layers in the Transformer decoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer decoder.

  • num_key_value_heads (int, optional, defaults to 32) – Number of key and value heads for each attention layer in the Transformer decoder.

  • n_shared_experts (int, optional) – Number of shared experts, which are applied to every token.

  • n_routed_experts (int, optional) – Number of routed experts; None means a dense model.

  • ep_size (int, optional, defaults to 1) – Expert-parallel size.

  • routed_scaling_factor (float, optional, defaults to 1.0) – Scaling factor applied to the routed experts.

  • kv_lora_rank (int, optional, defaults to 512) – Rank of the low-rank (latent) compression of keys and values in Multi-head Latent Attention.

  • q_lora_rank (int, optional, defaults to 1536) – Rank of the low-rank compression of queries.

  • qk_rope_head_dim (int, optional, defaults to 64) – Per-head dimension of the query/key part that receives rotary position embeddings.

  • v_head_dim (int, optional, defaults to 128) – Per-head dimension of the values.

  • qk_nope_head_dim (int, optional, defaults to 128) – Per-head dimension of the query/key part without rotary position embeddings.

  • topk_method (str, optional, defaults to “gready”) – Top-k method used in the routed gate.

  • n_group (int, optional) – Number of groups for routed experts.

  • topk_group (int, optional) – Number of expert groups selected for each token.

  • num_experts_per_tok (int, optional) – Number of routed experts selected per token.

  • moe_layer_freq (int, optional, defaults to 1) – The frequency of the MoE layer: one expert layer for every moe_layer_freq - 1 dense layers.

  • first_k_dense_replace (int, optional, defaults to 0) – Number of initial layers that use a dense MLP instead of an MoE block.

  • norm_topk_prob (bool, optional, defaults to False) – Whether to normalize top-k probabilities.

  • scoring_func (str, optional, defaults to “softmax”) – Method of computing expert weights.

  • aux_loss_alpha (float, optional, defaults to 0.001) – Auxiliary loss weight coefficient.

  • seq_aux (bool, optional, defaults to True) – Whether to compute the auxiliary loss for each individual sample.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the decoder. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional) – The index of the padding token in the vocabulary.

  • bos_token_id (int, optional, defaults to 100000) – The index of the beginning of sequence token in the vocabulary.

  • eos_token_id (int, optional, defaults to 100001) – The index of the end of sequence token in the vocabulary.

  • pretraining_tp (int, optional, defaults to 1) – Experimental feature. Tensor parallelism rank used during pretraining.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • attention_bias (bool, optional, defaults to False) – Whether to use attention bias.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • use_scan_mlp (bool, optional, defaults to False) – Whether to use scan for MLP.

  • scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size for scan MLP.

  • bits (int, optional) – The number of bits to quantize the model to.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The rope scaling configuration.
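With the defaults above, the Multi-head Latent Attention (MLA) shapes follow from simple arithmetic. The sketch assumes the DeepSeek-V2 design, in which each query/key head concatenates a non-rotary part (qk_nope_head_dim) and a rotary part (qk_rope_head_dim). It also assumes a compressed cache that stores the latent (kv_lora_rank) plus one shared rotary key (qk_rope_head_dim) per token and layer. Whether EasyDeL's KV cache actually uses this compressed layout is not stated here.

```python
from typing import Dict


def mla_dims(
    num_attention_heads: int = 32,
    kv_lora_rank: int = 512,
    qk_rope_head_dim: int = 64,
    qk_nope_head_dim: int = 128,
    v_head_dim: int = 128,
) -> Dict[str, int]:
    qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
    return {
        # Per-head width of queries and keys (non-rotary + rotary parts).
        "qk_head_dim": qk_head_dim,
        # Total width of all query heads.
        "q_proj_out": num_attention_heads * qk_head_dim,
        # Width of the attention output before the output projection.
        "attn_out": num_attention_heads * v_head_dim,
        # Values cached per token per layer with a compressed-latent cache.
        "latent_cache_per_token": kv_lora_rank + qk_rope_head_dim,
        # Values per token per layer if full per-head K and V were cached instead.
        "full_kv_cache_per_token": num_attention_heads * (qk_head_dim + v_head_dim),
    }


print(mla_dims()["latent_cache_per_token"], mla_dims()["full_kv_cache_per_token"])  # 576 10240
```

Under these assumptions, the latent cache holds 576 values per token per layer, versus 10,240 for a full per-head K/V cache.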

get_partition_rules(*args, **kwargs)

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()
property granted_freq_max_position_embedding: int
property granted_mask_max_position_embedding: int
model_type: str = 'deepseek_v2'
static rng_keys()
class easydel.__init__.DeepseekV2ForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.__init__.DeepseekV2Model(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

property frequencies

Returns frequency values from the config.

class easydel.__init__.DeepseekV3Config(vocab_size=129280, hidden_size=7168, intermediate_size=18432, moe_intermediate_size=2048, num_hidden_layers=61, num_nextn_predict_layers=1, num_attention_heads=128, num_key_value_heads=128, n_shared_experts=1, n_routed_experts=256, ep_size=1, routed_scaling_factor=2.5, kv_lora_rank=512, q_lora_rank=1536, qk_rope_head_dim=64, v_head_dim=128, qk_nope_head_dim=128, topk_method='noaux_tc', n_group=8, topk_group=4, num_experts_per_tok=8, moe_layer_freq=1, first_k_dense_replace=3, norm_topk_prob=True, scoring_func='sigmoid', aux_loss_alpha=0.001, seq_aux=True, hidden_act='silu', max_position_embeddings=4096, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, pad_token_id=None, bos_token_id=0, eos_token_id=1, pretraining_tp=1, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, attention_bias=False, attention_dropout=0.0, **kwargs)

Bases: EasyDeLBaseConfig

This is the configuration class to store the configuration of a [DeepseekV3Model]. It is used to instantiate a DeepSeek model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration similar to that of DeepSeek-V3. Configuration objects inherit from [PretrainedConfig] and can be used to control the model outputs. Read the documentation from [PretrainedConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 129280) – Vocabulary size of the DeepSeek-V3 model. Defines the number of different tokens that can be represented by the inputs_ids passed when calling [DeepseekV3Model].

  • hidden_size (int, optional, defaults to 7168) – Dimension of the hidden representations.

  • intermediate_size (int, optional, defaults to 18432) – Dimension of the MLP representations.

  • moe_intermediate_size (int, optional, defaults to 2048) – Dimension of the MoE representations.

  • num_hidden_layers (int, optional, defaults to 61) – Number of hidden layers in the Transformer decoder.

  • num_nextn_predict_layers (int, optional, defaults to 1) – Number of nextn prediction layers in the DeepSeek-V3 model.

  • num_attention_heads (int, optional, defaults to 128) – Number of attention heads for each attention layer in the Transformer decoder.

  • n_shared_experts (int, optional, defaults to 1) – Number of shared experts; None means a dense model.

  • n_routed_experts (int, optional, defaults to 256) – Number of routed experts; None means a dense model.

  • routed_scaling_factor (float, optional, defaults to 2.5) – Scaling factor for routed experts.

  • topk_method (str, optional, defaults to “noaux_tc”) – Top-k method used in the routed gate.

  • n_group (int, optional, defaults to 8) – Number of groups for routed experts.

  • topk_group (int, optional, defaults to 4) – Number of groups selected for each token; each token's selected experts must all lie within these topk_group groups.

  • num_experts_per_tok (int, optional, defaults to 8) – Number of selected experts; None means a dense model.

  • moe_layer_freq (int, optional, defaults to 1) – The frequency of the MoE layer: one expert layer for every moe_layer_freq - 1 dense layers.

  • first_k_dense_replace (int, optional, defaults to 3) – Number of dense layers at the start of the network, before the MoE layers begin (embed -> k dense layers -> moe -> … -> moe -> lm_head).

  • norm_topk_prob (bool, optional, defaults to True) – Whether to normalize the weights of the routed experts.

  • scoring_func (str, optional, defaults to ‘sigmoid’) – Method of computing expert weights.

  • aux_loss_alpha (float, optional, defaults to 0.001) – Auxiliary loss weight coefficient.

  • seq_aux (bool, optional, defaults to True) – Whether to compute the auxiliary loss for each individual sample.

  • num_key_value_heads (int, optional, defaults to 128) – This is the number of key_value heads that should be used to implement Grouped Query Attention. If num_key_value_heads=num_attention_heads, the model will use Multi Head Attention (MHA); if num_key_value_heads=1, the model will use Multi Query Attention (MQA); otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by meanpooling all the original heads within that group. For more details, check out [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is None, it defaults to num_attention_heads.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) in the decoder.

  • max_position_embeddings (int, optional, defaults to 4096) – The maximum sequence length that this model might ever be used with.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-06) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional) – Padding token id.

  • bos_token_id (int, optional, defaults to 0) – Beginning of stream token id.

  • eos_token_id (int, optional, defaults to 1) – End of stream token id.

  • pretraining_tp (int, optional, defaults to 1) – Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is necessary to ensure exact reproducibility of the pretraining results. Please refer to [this issue](pytorch/pytorch#76232).

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie weight embeddings

  • rope_theta (float, optional, defaults to 10000.0) – The base period of the RoPE embeddings.

  • rope_scaling (Dict, optional) – Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is {“type”: strategy name, “factor”: scaling factor}. When using this flag, don’t update max_position_embeddings to the expected new maximum.

  • attention_bias (bool, optional, defaults to False) – Whether to use a bias in the query, key, value and output projection layers during self-attention.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
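first_k_dense_replace and moe_layer_freq together decide which decoder layers use an MoE block. The sketch below follows the rule in the reference DeepSeek modeling code. A layer is MoE when n_routed_experts is set, its index is at least first_k_dense_replace, and index % moe_layer_freq == 0. EasyDeL is assumed, not confirmed, to apply the same rule.

```python
from typing import List, Optional


def moe_layer_indices(
    num_hidden_layers: int = 61,
    first_k_dense_replace: int = 3,
    moe_layer_freq: int = 1,
    n_routed_experts: Optional[int] = 256,
) -> List[int]:
    if n_routed_experts is None:
        return []  # dense model: no MoE layers at all
    return [
        i
        for i in range(num_hidden_layers)
        if i >= first_k_dense_replace and i % moe_layer_freq == 0
    ]


layers = moe_layer_indices()
print(layers[0], len(layers))  # 3 58
```

With the DeepSeek-V3 defaults, layers 0 to 2 stay dense and the remaining 58 layers are MoE.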

Examples:

>>> from easydel import DeepseekV3Config
>>> # Initializing a DeepSeek-V3 style configuration
>>> configuration = DeepseekV3Config()
>>> configuration.model_type
'deepseek_v3'
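With topk_method='noaux_tc', n_group and topk_group restrict each token's experts to a few expert groups. The sketch below follows the DeepSeek-V3 routing as we understand it:

1. Compute sigmoid scores (scoring_func='sigmoid').
2. Rank each group by the sum of its two best scores and keep the top topk_group groups.
3. Pick the num_experts_per_tok best experts inside the kept groups.
4. Renormalize their weights (norm_topk_prob) and scale by routed_scaling_factor.

The learned bias correction that noaux_tc adds during expert selection is omitted. The small example sizes are chosen for readability.

```python
import math
from typing import List, Tuple


def group_limited_topk(
    logits: List[float],
    n_group: int,
    topk_group: int,
    num_experts_per_tok: int,
    norm_topk_prob: bool = True,
    routed_scaling_factor: float = 2.5,
) -> List[Tuple[int, float]]:
    n_experts = len(logits)
    if n_experts % n_group != 0:
        raise ValueError("experts must split evenly into groups")
    group_size = n_experts // n_group
    # scoring_func='sigmoid': an independent score per expert.
    scores = [1.0 / (1.0 + math.exp(-x)) for x in logits]
    # Rank groups by the sum of their two highest expert scores.
    group_scores = []
    for g in range(n_group):
        members = sorted(scores[g * group_size:(g + 1) * group_size], reverse=True)
        group_scores.append(sum(members[:2]))
    kept_groups = sorted(range(n_group), key=lambda g: -group_scores[g])[:topk_group]
    # Only experts inside the kept groups are eligible.
    candidates = [e for e in range(n_experts) if e // group_size in kept_groups]
    chosen = sorted(candidates, key=lambda e: -scores[e])[:num_experts_per_tok]
    weights = [scores[e] for e in chosen]
    if norm_topk_prob:
        total = sum(weights)
        weights = [w / total for w in weights]
    return [(e, w * routed_scaling_factor) for e, w in zip(chosen, weights)]
```

Group limiting can exclude a strong expert that sits in a weak group. With 8 experts in 4 groups and logits [5, 4, 4.5, -9, 3, 2, -5, -5], expert 2 is skipped in favour of expert 4, because group 1's two-best sum is low.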

get_partition_rules(*args, **kwargs)

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

keys_to_ignore_at_inference = ['past_key_values']
model_type: str = 'deepseek_v3'
class easydel.__init__.DeepseekV3ForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.__init__.DeepseekV3Model(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

property frequencies

Returns frequency values from the config.

class easydel.__init__.EasyDeLBackends(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: str, Enum

Enum defining available backends for EasyDeL. Each enum member represents a different kernel usage approach.

CPU = 'cpu'
GPU = 'gpu'
TPU = 'tpu'
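Because EasyDeLBackends derives from both str and Enum, its members compare equal to their plain string values and can be looked up by value. The snippet defines a local mirror of the three documented members so that it runs without EasyDeL installed. With EasyDeL available, you would import the real enum from easydel instead.

```python
from enum import Enum


class EasyDeLBackends(str, Enum):
    # Local mirror of the members documented above.
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"


backend = EasyDeLBackends("gpu")  # look up a member by its value
print(backend is EasyDeLBackends.GPU)  # True
print(backend == "gpu")  # True: the str mixin compares by value
print([b.value for b in EasyDeLBackends])  # ['cpu', 'gpu', 'tpu']
```

The str mixin means a backend can be passed wherever a plain string such as "tpu" is expected, such as in config files or CLI flags.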
class easydel.__init__.EasyDeLBaseConfig(axis_dims: ~typing.Sequence[int] = (1, -1, 1, 1), dcn_axis_dims: ~typing.Optional[~typing.Sequence[int]] = None, axis_names: ~typing.Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), attn_mechanism: ~typing.Literal['vanilla', 'flash_attn2', 'splash', 'ring', 'cudnn', 'blockwise', 'sdpa'] = 'vanilla', blocksize_k: int = 128, blocksize_q: int = 128, blocksize_b: int = 1, partition_axis: ~eformer.escale.partition.constraints.PartitionAxis = PartitionAxis(batch_axis=('fsdp', 'dp'), sequence_axis='sp', query_sequence_axis='sp', head_axis='tp', key_sequence_axis='sp', hidden_state_axis='tp', attention_dim_axis=None, bias_head_sequence_axis=None, bias_key_sequence_axis=None, generation_query_sequence_axis=None, generation_head_axis='tp', generation_key_sequence_axis='sp', generation_attention_dim_axis=None), shard_attention_computation: bool = True, use_sharded_kv_caching: bool = False, use_sharding_constraint: bool = False, backend: ~typing.Optional[~easydel.__init__.infra.etils.EasyDeLBackends] = None, platform: ~typing.Optional[~easydel.__init__.infra.etils.EasyDeLPlatforms] = None, easy_method: ~typing.Literal['train', 'serve', 'convert'] = 'train', bits: ~typing.Optional[int] = None, scan_ring_attention: bool = True, scan_attention_layers: bool = False, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, sequence_axis_name: str = 'sp', gradient_checkpointing: ~easydel.__init__.infra.etils.EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, kv_cache_quantization_method: ~easydel.__init__.infra.etils.EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.NONE, kv_cache_quantization_blocksize: int = 64, quantization_method: ~easydel.__init__.infra.etils.EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.NONE, quantization_pattern: str = '.*', quantization_blocksize: int = 64, kv_cache_sharding_sequence_axis_name: ~typing.Union[str, ~typing.Tuple[str, ...]] = 'sp', flash_attention_backward_pass_impl: ~typing.Literal['triton', 'xla'] = 'triton', attn_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, attn_softmax_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, fcm_max_ratio: float = 0.0, fcm_min_ratio: float = 0.0, hardware_abstraction: bool = False, pallas_m_block_size: int = 128, pallas_k_block_size: int = 128, pallas_n_block_size: int = 128, **kwargs)

Bases: PretrainedConfig

Initialize the configuration for EasyDeL.

Parameters
  • axis_dims (tp.Sequence[int]) – Dimensions of the mesh axes. Defaults to (1, -1, 1, 1).

  • axis_names (tp.Sequence[str]) – Names of the mesh axes. Defaults to (“dp”, “fsdp”, “tp”, “sp”).

  • attn_mechanism (AVAILABLE_ATTENTION_MECHANISMS) – Attention mechanism to use. Defaults to “vanilla”.

  • blocksize_k (int) – Block size for keys. Defaults to 128.

  • blocksize_q (int) – Block size for queries. Defaults to 128.

  • blocksize_b (int) – Block size for the batch dimension. Defaults to 1.

  • partition_axis (PartitionAxis) – Partition axis configuration. Defaults to PartitionAxis().

  • shard_attention_computation (bool) – Whether to shard the attention computation. Defaults to True.

  • use_sharded_kv_caching (bool) – Whether to use sharded key/value caching. Defaults to False.

  • use_sharding_constraint (bool) – Whether to apply sharding constraints. Defaults to False.

  • backend (tp.Optional[EasyDeLBackends]) – Backend to use. Defaults to None.

  • platform (tp.Optional[EasyDeLPlatforms]) – Platform to use. Defaults to None.

  • easy_method (tp.Literal[“train”, “serve”, “convert”]) – Mode the model is used in. Defaults to “train”.

  • bits (tp.Optional[int]) – Number of bits for quantization. Defaults to None.

  • scan_ring_attention (bool) – Whether to use scan for ring attention. Defaults to True.

  • scan_attention_layers (bool) – Whether to use scan for the attention layers. Defaults to False.

  • use_scan_mlp (bool) – Whether to use the scan MLP. Defaults to False.

  • scan_mlp_chunk_size (int) – Chunk size for the scan MLP. Defaults to 1024.

  • sequence_axis_name (str) – Name of the sequence (attention) axis. Defaults to “sp”.

  • gradient_checkpointing (EasyDeLGradientCheckPointers) – Gradient checkpointing method. Defaults to EasyDeLGradientCheckPointers.NONE.

  • kv_cache_quantization_method (EasyDeLQuantizationMethods) – Key/value cache quantization method. Defaults to EasyDeLQuantizationMethods.NONE.

  • kv_cache_quantization_blocksize (int) – Block size for key/value cache quantization. Defaults to 64.

  • quantization_method (EasyDeLQuantizationMethods) – Quantization method. Defaults to EasyDeLQuantizationMethods.NONE.

  • quantization_pattern (str) – Regular expression selecting the layers to quantize. Defaults to “.*”.

  • quantization_blocksize (int) – Block size for quantization. Defaults to 64.

  • kv_cache_sharding_sequence_axis_name (tp.Union[str, tp.Tuple[str, …]]) – Name of the key/value cache sharding sequence axis. Defaults to “sp”.

  • flash_attention_backward_pass_impl (tp.Literal[“triton”, “xla”]) – Implementation of the flash attention backward pass. Defaults to “triton”.

  • attn_dtype (jnp.dtype) – Data type for attention. Defaults to jnp.float32.

  • attn_softmax_dtype (jnp.dtype) – Data type for the softmax ops in attention. Defaults to jnp.float32.

  • fcm_max_ratio (float) – Maximum ratio for FCM (forgetful causal masking). Defaults to 0.0.

  • fcm_min_ratio (float) – Minimum ratio for FCM (forgetful causal masking). Defaults to 0.0.

  • hardware_abstraction (bool) – Whether to use hardware abstraction. Defaults to False.

  • pallas_m_block_size (int) – Block size of the M dimension in Pallas matmul kernels. Defaults to 128.

  • pallas_k_block_size (int) – Block size of the K dimension in Pallas matmul kernels. Defaults to 128.

  • pallas_n_block_size (int) – Block size of the N dimension in Pallas matmul kernels. Defaults to 128.

  • **kwargs – Additional keyword arguments.

Raises

Warning – If kv_cache_quantization_method is not NONE and use_sharded_kv_caching is True.

add_basic_configurations(axis_dims: Sequence[int] = Ellipsis, dcn_axis_dims: Optional[Sequence[int]] = Ellipsis, axis_names: Sequence[str] = Ellipsis, attn_mechanism: Literal['vanilla', 'flash_attn2', 'splash', 'ring', 'cudnn', 'blockwise', 'sdpa'] = Ellipsis, blocksize_k: int = Ellipsis, blocksize_q: int = Ellipsis, blocksize_b: int = Ellipsis, partition_axis: PartitionAxis = Ellipsis, shard_attention_computation: bool = Ellipsis, use_sharded_kv_caching: bool = Ellipsis, backend: Optional[EasyDeLBackends] = Ellipsis, platform: Optional[EasyDeLPlatforms] = Ellipsis, easy_method: Literal['train', 'serve', 'convert'] = Ellipsis, bits: Optional[int] = Ellipsis, scan_ring_attention: bool = Ellipsis, scan_attention_layers: bool = Ellipsis, use_sharding_constraint: bool = Ellipsis, use_scan_mlp: bool = Ellipsis, scan_mlp_chunk_size: int = Ellipsis, sequence_axis_name: str = Ellipsis, gradient_checkpointing: EasyDeLGradientCheckPointers = Ellipsis, kv_cache_quantization_method: EasyDeLQuantizationMethods = Ellipsis, kv_cache_quantization_blocksize: int = Ellipsis, quantization_method: EasyDeLQuantizationMethods = Ellipsis, quantization_blocksize: int = Ellipsis, quantization_pattern: str = Ellipsis, kv_cache_sharding_sequence_axis_name: Union[str, Tuple[str, ...]] = Ellipsis, flash_attention_backward_pass_impl: Literal['triton', 'xla'] = Ellipsis, attn_dtype: dtype = Ellipsis, attn_softmax_dtype: dtype = Ellipsis, hardware_abstraction: bool = Ellipsis, pallas_m_block_size: int = Ellipsis, pallas_k_block_size: int = Ellipsis, pallas_n_block_size: int = Ellipsis, **kwargs)

Initializes the basic EasyDeL attributes of the configuration: mesh layout, attention, sharding, quantization, and kernel options.

Parameters
  • axis_dims (tp.Sequence[int], optional) – Size of each mesh axis; a -1 entry takes up all remaining devices. Defaults to (1, -1, 1, 1).

  • axis_names (tp.Sequence[str], optional) – Set the names of the axes. Defaults to (“dp”, “fsdp”, “tp”, “sp”).

  • attn_mechanism (AVAILABLE_ATTENTION_MECHANISMS, optional) – Attention mechanism to use. Defaults to DEFAULT_ATTENTION_MECHANISM.

  • blocksize_k (int, optional) – Block size for key_states. Defaults to 128.

  • blocksize_q (int, optional) – Block size for query_states. Defaults to 128.

  • blocksize_b (int, optional) – Block size for the batch dimension. Defaults to 1.

  • partition_axis (PartitionAxis, optional) – Maps logical array dimensions (batch, sequence, heads, hidden state, ...) to mesh axis names when partitioning arrays in EasyDeL. Defaults to PartitionAxis().

  • shard_attention_computation (bool, optional) – Whether to use shard_map for attention. Defaults to True.

  • use_sharded_kv_caching (bool, optional) – Whether to use shard_map and sharding for the key/value cache. Defaults to True.

  • backend (tp.Optional[EasyDeLBackends], optional) – Specify the backend to use. Defaults to None.

  • platform (tp.Optional[EasyDeLPlatforms], optional) – Specify the platform to use. Defaults to None.

  • easy_method (tp.Literal["train", "serve", "convert"], optional) – Mode the model is used in: training, serving, or conversion. Defaults to EasyMethod.TRAIN.

  • bits (tp.Optional[int], optional) – Model bits for quantization. Defaults to None.

  • scan_ring_attention (bool, optional) – Whether to use scan for ring attention. Defaults to True.

  • scan_attention_layers (bool, optional) – Whether to use scan for the attention layers. Defaults to False.

  • use_sharding_constraint (bool, optional) – Whether to use sharding constraints for the arrays. Defaults to False.

  • use_scan_mlp (bool, optional) – Determine whether to use scan_mlp or not. Defaults to False.

  • scan_mlp_chunk_size (int, optional) – Size of chunks in scan MLP. Defaults to 1024.

  • sequence_axis_name (str, optional) – Name of the sequence (attention) axis. Defaults to “sp”.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing method for the created or loaded module (usually applied to the MLP and attention layers).

  • kv_cache_quantization_method (EasyDeLQuantizationMethods, optional) – key and value quantization type. Defaults to EasyDeLQuantizationMethods.NONE.

  • kv_cache_quantization_blocksize (int, optional) – Block size for key/value cache quantization. Defaults to 64.

  • quantization_method (EasyDeLQuantizationMethods, optional) – Quantization type for linear modules. Defaults to EasyDeLQuantizationMethods.NONE.

  • quantization_blocksize (int, optional) – Block size for linear-module quantization. Defaults to 64.

  • quantization_pattern (str) – Regular expression selecting the layers to quantize.

  • kv_cache_sharding_sequence_axis_name (tp.Union[str, tp.Tuple[str, ...]], optional) – axis name to target for sharding sequences. Defaults to “sp”.

  • flash_attention_backward_pass_impl (tp.Literal["triton", "xla"], optional) – Specify the backward pass kernel for flash attention. Defaults to “triton”.

  • attn_dtype (jnp.dtype, optional) – Data type for attention computations. Defaults to device half.

  • attn_softmax_dtype (jnp.dtype, optional) – Data type for softmax in attention op computations. Defaults to jnp.float32.

  • fcm_max_ratio (float, optional) – Maximum ratio for forgetful causal masking (FCM). Defaults to 0.0.

  • fcm_min_ratio (float, optional) – Minimum ratio for forgetful causal masking (FCM). Defaults to 0.0.

  • hardware_abstraction (bool, optional) – Whether to switch to custom Pallas kernels instead of JAX. Defaults to DEFAULT_HARDWARE_ABSTRACTION.

  • pallas_m_block_size (int, optional) – Block size of the m dimension in the Pallas matmul kernel A(m×k) @ B(k×n) = C(m×n). Defaults to DEFAULT_PALLAS_M_BLOCK_SIZE.

  • pallas_k_block_size (int, optional) – Block size of the k dimension in the Pallas matmul kernel A(m×k) @ B(k×n) = C(m×n). Defaults to DEFAULT_PALLAS_K_BLOCK_SIZE.

  • pallas_n_block_size (int, optional) – Block size of the n dimension in the Pallas matmul kernel A(m×k) @ B(k×n) = C(m×n). Defaults to DEFAULT_PALLAS_N_BLOCK_SIZE.

attach_custom_arguments(**kwargs)
static create_mesh(axis_dims: Sequence[int] = (1, -1, 1, 1), axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), dcn_axis_dims: Optional[Sequence[int]] = None, process_is_granule: bool = False, should_sort_granules_by_key: bool = True, allow_split_physical_axes: bool = True, backend: Optional[str] = None)

The create_mesh function creates a mesh object that can be used to shard arrays.

Returns

A mesh object
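The sketch below shows the kind of mesh described here. It calls jax.sharding.Mesh directly instead of create_mesh itself. Following the default (1, -1, 1, 1), it assumes that a single -1 entry takes up all remaining devices. make_mesh is a hypothetical helper, and it ignores dcn_axis_dims and the granule options that create_mesh also accepts.

```python
import jax
import numpy as np
from jax.sharding import Mesh


def make_mesh(axis_dims=(1, -1, 1, 1), axis_names=("dp", "fsdp", "tp", "sp")):
    """Hypothetical helper: lay out all visible devices as a named mesh."""
    devices = np.asarray(jax.devices())
    # numpy infers the single -1 entry from the total device count
    device_grid = devices.reshape(axis_dims)
    return Mesh(device_grid, axis_names)


mesh = make_mesh()
# With 8 devices the "fsdp" axis would have size 8; on one device every axis has size 1.
print(dict(mesh.shape))
```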

classmethod from_pretrained(pretrained_model_name_or_path: Union[str, PathLike], cache_dir: Optional[Union[str, PathLike]] = None, force_download: bool = False, local_files_only: bool = False, token: Optional[Union[bool, str]] = None, revision: str = 'main', **kwargs) → PretrainedConfig

Instantiate a [PretrainedConfig] (or a derived class) from a pretrained model configuration.

Parameters
  • pretrained_model_name_or_path (str or os.PathLike) –

    This can be either:

    • a string, the model id of a pretrained model configuration hosted inside a model repo on huggingface.co.

    • a path to a directory containing a configuration file saved using the [~PretrainedConfig.save_pretrained] method, e.g., ./my_model_directory/.

    • a path or url to a saved configuration JSON file, e.g., ./my_model_directory/configuration.json.

  • cache_dir (str or os.PathLike, optional) – Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.

  • force_download (bool, optional, defaults to False) – Whether or not to force to (re-)download the configuration files and override the cached versions if they exist.

  • resume_download – Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.

  • proxies (Dict[str, str], optional) – A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.

  • token (str or bool, optional) – The token to use as HTTP bearer authorization for remote files. If True, or not specified, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).

  • revision (str, optional, defaults to “main”) –

    The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.

    To test a pull request you made on the Hub, you can pass revision="refs/pr/<pr_number>".

  • return_unused_kwargs (bool, optional, defaults to False) –

    If False, then this function returns just the final configuration object.

    If True, then this function returns a tuple (config, unused_kwargs) where unused_kwargs is a dictionary of the key/value pairs whose keys are not configuration attributes, i.e., the part of kwargs that was not used to update config and is otherwise ignored.

  • subfolder (str, optional, defaults to “”) – In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here.

  • kwargs (Dict[str, tp.Any], optional) – The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not configuration attributes is controlled by the return_unused_kwargs keyword parameter.

Returns

The configuration object instantiated from this pretrained model.

Return type

[PretrainedConfig]

Examples:

>>> from transformers import BertConfig
>>> # PretrainedConfig itself cannot be instantiated directly, so these examples
>>> # use a derived class: BertConfig
>>> config = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased"
... )  # Download configuration from huggingface.co and cache.
>>> config = BertConfig.from_pretrained(
...     "./test/saved_model/"
... )  # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
>>> config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
>>> config = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased", output_attentions=True, foo=False
... )
>>> assert config.output_attentions == True
>>> config, unused_kwargs = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased",
...     output_attentions=True,
...     foo=False,
...     return_unused_kwargs=True,
... )
>>> assert config.output_attentions == True
>>> assert unused_kwargs == {"foo": False}


get_axis_dims() → Sequence[int]

The get_axis_dims function returns a sequence of integers giving the size of each axis.

Returns

The dimensions of the axes

get_axis_names() → Sequence[str]

The get_axis_names function returns a list of the names of the axes.

Returns

A list of the names of all axes

get_backend() → str

The get_backend function returns the backend that is currently being used. If no backend has been set, it returns the default JAX backend.

Returns

The backend platform

get_basic_causal_mask(*args, **kwargs)
get_basic_frequencies(head_size: Optional[int] = None, rotary_dim: Optional[int] = None, base: Optional[float] = None) → Any

Get basic frequencies for rotary embeddings.

Parameters
  • head_size – Size of attention heads (defaults to self.head_dim)

  • rotary_dim – Dimension for rotary embeddings (defaults to head_size)

  • base – Base value for frequency computation (defaults to self.rope_theta)

Returns

ModuleCaches instance containing computed frequencies
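As a reference point, the sketch below computes standard, unscaled rotary frequencies. Here base plays the role of rope_theta (for example 1000000.0 in ArcticConfig or 10000.0 in ExaoneConfig). It follows the usual RoPE formula and is not EasyDeL's implementation. It ignores rope_scaling and returns plain cos/sin tables rather than a ModuleCaches instance. rope_frequencies is a hypothetical name.

```python
import jax.numpy as jnp


def rope_frequencies(max_position: int, rotary_dim: int, base: float = 10000.0):
    """Return (cos, sin) tables of shape (max_position, rotary_dim // 2)."""
    # One inverse frequency per channel pair: base ** (-2i / rotary_dim)
    inv_freq = 1.0 / (base ** (jnp.arange(0, rotary_dim, 2, dtype=jnp.float32) / rotary_dim))
    positions = jnp.arange(max_position, dtype=jnp.float32)
    # Rotation angle for every (position, channel pair)
    angles = jnp.outer(positions, inv_freq)
    return jnp.cos(angles), jnp.sin(angles)


cos, sin = rope_frequencies(max_position=16, rotary_dim=8)
print(cos.shape)  # (16, 4)
```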

get_basic_rope(dtype: Union[Array, ndarray, bool, number], head_size: int, rotary_dim: Optional[int] = None, is_neox_style: bool = True, base: Optional[float] = None)

Get basic rotary position embeddings.

Parameters
  • dtype – Data type for the embeddings

  • head_size – Size of attention heads

  • rotary_dim – Dimension for rotary embeddings (defaults to head_size)

  • is_neox_style – Whether to use NeoX style embeddings

  • base – Base value for frequency computation (defaults to self.rope_theta)

Returns

A function that applies rotary position embeddings
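In common RoPE implementations, is_neox_style selects how channels are paired before rotation. NeoX style rotates the first half of each head against the second half, while GPT-J style rotates adjacent (even, odd) channels. The sketch below shows both conventions for a full-width rotation, where rotary_dim equals head_size. apply_rope is a hypothetical helper, not the function returned by get_basic_rope.

```python
import jax.numpy as jnp


def apply_rope(x, positions, base: float = 10000.0, is_neox_style: bool = True):
    """Rotate x of shape (seq, heads, head_dim) by its token positions (seq,)."""
    dim = x.shape[-1]
    inv_freq = 1.0 / (base ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim))
    angles = positions.astype(jnp.float32)[:, None] * inv_freq[None, :]  # (seq, dim // 2)
    cos = jnp.cos(angles)[:, None, :]  # broadcast over the heads axis
    sin = jnp.sin(angles)[:, None, :]
    if is_neox_style:
        # NeoX style: channel i is paired with channel i + dim // 2
        x1, x2 = x[..., : dim // 2], x[..., dim // 2 :]
        return jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    # GPT-J style: channels (2i, 2i + 1) form a pair; re-interleave after rotating
    x1, x2 = x[..., 0::2], x[..., 1::2]
    rotated = jnp.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    return rotated.reshape(x.shape)


q = jnp.ones((5, 2, 8))  # (seq, heads, head_dim)
q_rot = apply_rope(q, jnp.arange(5))
```

Both styles leave position 0 unchanged and preserve each vector's norm. They differ only in which channels are rotated together, so weights trained with one convention need that same convention at inference.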

get_fcm_mask(batch_size, seq_length, deterministic: bool)
get_partition_rules(*args, **kwargs)

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]
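Each rule pairs a regular expression over a parameter's path with a PartitionSpec. The sketch below assumes the rules are tried in order and the first match wins. The parameter paths, the rules, and the spec_for helper are all hypothetical and only illustrate the shape of the returned tuple.

```python
import re

from jax.sharding import PartitionSpec as PS

# Hypothetical rules in the (regex, PartitionSpec) form described above
PARTITION_RULES = (
    (r"embed_tokens/embedding", PS("tp", ("fsdp", "sp"))),
    (r"self_attn/(q_proj|k_proj|v_proj)/kernel", PS(("fsdp", "sp"), "tp")),
    (r"mlp/down_proj/kernel", PS("tp", ("fsdp", "sp"))),
    (r".*", PS()),  # catch-all: replicate everything else
)


def spec_for(param_path: str, rules=PARTITION_RULES) -> PS:
    """Return the spec of the first rule whose regex matches the path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


print(spec_for("model/layers/0/self_attn/q_proj/kernel"))
```

A final catch-all rule keeps every parameter covered. Without it, a newly added layer would raise instead of silently falling back to replication.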

property granted_freq_max_position_embedding: int
property granted_mask_max_position_embedding: int
jax_mesh()
property mesh

The mesh property is a helper property that creates a Mesh object from the axis_dims and axis_names attributes of an object, which are assumed to be sequences of integers and strings, respectively. The platform attribute is also used if it exists.

Returns

A jax.sharding.Mesh

read_basics_from_config(config: EasyDeLBaseConfig)
to_dict() → Dict[str, Any]

Serializes this instance to a Python dictionary.

Returns

Dictionary of all the attributes that make up this configuration instance.

Return type

Dict[str, Any]

class easydel.__init__.EasyDeLBaseConfigDict

Bases: TypedDict

class easydel.__init__.EasyDeLBaseModule(*args: Any, **kwargs: Any)

Bases: Module, BaseModuleProtocol, EasyBridgeMixin, EasyGenerationMixin

Base class for EasyDeL modules, providing common functionality for model initialization, parameter handling, and integration with the EasyDeL ecosystem.

apply_lora_to_layers(lora_rank: int, lora_pattern: Optional[str] = None, verbose: bool = False, rngs: Optional[Rngs] = None) → SELF

Applies LoRA (Low-Rank Adaptation) to specified linear layers within a model.
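As background, LoRA keeps a linear layer's weight W frozen and learns a low-rank update, so the layer computes x W + (alpha / r) x A B, with A of shape (d_in, r) and B of shape (r, d_out). The sketch below shows that computation and why the update can be folded back into W, which is the idea behind merging or unwrapping an adapter. The alpha / r scaling and the helper names are illustrative assumptions, not EasyDeL's LoRA module.

```python
import jax
import jax.numpy as jnp


def lora_linear(x, w, a, b, alpha: float, rank: int):
    # Frozen base projection plus the scaled low-rank update
    return x @ w + (alpha / rank) * (x @ a @ b)


def merge_lora(w, a, b, alpha: float, rank: int):
    # Fold the low-rank update into a single dense weight
    return w + (alpha / rank) * (a @ b)


k_w, k_a, k_x = jax.random.split(jax.random.PRNGKey(0), 3)
d_in, d_out, rank = 16, 8, 4
w = jax.random.normal(k_w, (d_in, d_out))
a = jax.random.normal(k_a, (d_in, rank)) * 0.01
b = jnp.zeros((rank, d_out))  # zero init: the adapter starts as a no-op
x = jax.random.normal(k_x, (2, d_in))

y = lora_linear(x, w, a, b, alpha=8.0, rank=rank)
```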

property causal_mask: Array

Returns a causal mask from the config.

compute_loss(*, labels: Optional[Union[Array, ndarray, bool, number]] = None, loss_config: Optional[LossConfig] = None, loss_kwargs: Optional[Dict] = None, **batch) → Tuple[Any, LossMetrics]

Basic loss computation for the given batch and labels.

float(change_runtime_dtype: bool = True) → SELF

Converts the model parameters to float32.

property frequencies: Array

Returns frequency values from the config.

fully_gather() → SELF
fully_shard(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None) → SELF
gather_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None, overlay_fns: Optional[Mapping[str, Callable]] = None) → SELF

Gathers the model’s parameters based on the specified partitioning rules and mesh.

Parameters
  • partition_rules (PartitionLike, optional) – Partitioning rules for gathering.

  • mesh (jax.sharding.Mesh, optional) – The mesh to gather from.

  • overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]], optional) – Overlay functions to apply to the model.

Returns

The gathered model.

Return type

EasyDeLBaseModule

get_static_arguments() → Tuple

Returns the static arguments (kwargs) for jax.jit.

property graphdef: Union[NodeDef[Node], NodeRef[Node]]
property graphother: State[Key, VariableState[Any]]
property graphstate: State[Key, VariableState[Any]]
property graphtree_params_shape: Dict

Evaluates the shape of the model’s parameters and returns a dictionary.

property graphtree_shape: Dict

Evaluates the shape of the model and returns a dictionary.

half(change_runtime_dtype: bool = True) → SELF

Converts the model parameters to float16.

classmethod lazy_init(*args, **kwargs) → SELF

Initializes the class under nnx.eval_shape, so parameter shapes and dtypes are computed without allocating the parameters.

property loss_function
merge_lora_params(pytree: Dict) → SELF

Merges the given pytree of LoRA parameters into the current LoRA module.

merge_params(tree)

Merges the given state into the current model.

merge_params_dict(params_dict: Dict) → SELF

Merges the model parameters from a dictionary into the current model.

Parameters

params_dict (tp.Dict) – A dictionary containing the parameters to merge.

Returns

The model with merged parameters.

Return type

EasyDeLBaseModule

property mesh: Mesh

Returns the mesh from the config.

property model_task: Optional[str]

Returns the model task.

property model_type: Optional[str]

Returns the model type.

property module_dtype: dtype
property parameters: Dict
property params: Dict
property params_sharding: Dict

Returns the sharding of the model parameters.

prepare_inputs_for_call(**kwargs)

Updates the inputs for calling the model.

property pure_transform_fn

Generates a pure transform function for converting a PyTorch model to an EasyDeL module.

quantize(method: EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.A8BIT, block_size: int = 128, quantization_pattern: Optional[str] = None, quantize_tensors: bool = True, verbose: Optional[bool] = None) → SELF

Quantizes the model’s linear layers.

Parameters
  • method (EasyDeLQuantizationMethods, optional) – The quantization method to use.

  • block_size (int, optional) – The block size for quantization.

  • quantization_pattern (str, optional) – The quantization pattern to use.

  • quantize_tensors (bool) – Whether to quantize the tensors directly instead of the Linear layers.

  • verbose (bool, optional) – Whether to report progress during quantization.

Returns

The quantized model.

Return type

EasyDeLBaseModule
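Block-wise quantization stores each block of block_size weights as small integers plus one scale per block. The sketch below is a generic absmax int8 version of that idea with hypothetical helper names. It is not EasyDeL's A8BIT kernel, and NF4 uses a 4-bit normal-float codebook rather than uniform integer levels.

```python
import jax.numpy as jnp


def quantize_blockwise_int8(w, block_size: int = 128):
    """Absmax int8 quantization with one float scale per block of block_size values."""
    flat = w.reshape(-1, block_size)  # assumes w.size is divisible by block_size
    scale = jnp.max(jnp.abs(flat), axis=1, keepdims=True) / 127.0
    scale = jnp.where(scale == 0, 1.0, scale)  # all-zero blocks would otherwise divide by 0
    q = jnp.clip(jnp.round(flat / scale), -127, 127).astype(jnp.int8)
    return q, scale


def dequantize_blockwise_int8(q, scale, shape):
    # Undo the per-block scaling and restore the original layout
    return (q.astype(jnp.float32) * scale).reshape(shape)


w = jnp.linspace(-1.0, 1.0, 256).reshape(16, 16)
q, scale = quantize_blockwise_int8(w, block_size=64)
w_hat = dequantize_blockwise_int8(q, scale, w.shape)
```

Smaller blocks track local magnitude more closely and so reduce rounding error. The cost is storing more scales.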

shard_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None, overlay_fns: Optional[Mapping[str, Callable]] = None) → SELF

Shards the model’s parameters using the specified partitioning rules and mesh.

Parameters
  • partition_rules (PartitionLike, optional) – Partitioning rules for sharding.

  • mesh (jax.sharding.Mesh, optional) – The mesh to shard across.

  • overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]], optional) – Overlay functions to apply to the model.

Returns

The sharded model.

Return type

EasyDeLBaseModule

split_lora_params() → Dict

Splits the given LoRA module and returns its LoRA parameters.

split_params()

Splits the model parameters.

split_params_dict(extract_fn: Optional[Callable] = None, remove_none: bool = True) → Dict

Splits the model parameters and returns them as a dictionary, removing VariableState from the tree.

Parameters
  • extract_fn (tp.Optional[tp.Callable], optional) – Function to extract values from the parameters.

  • remove_none (bool, optional) – Whether to remove None values from the dictionary.

Returns

The dictionary of split parameters.

Return type

tp.Dict

property static_arguments: Tuple
to_dtype(dtype: dtype) → SELF

Converts the model’s state to the given dtype.

to_state() → Any

Converts the current model to an EasyDeLState.

to_torch(**kwargs)

Converts the current model to a Hugging Face PyTorch model.

property transform_fn

Generates a transform function for converting a PyTorch model to an EasyDeL module.

unwrap_lora_to_layers(verbose: bool = False) → SELF

Unwraps LoRA (Low-Rank Adaptation) from the specified linear layers within a model.

class easydel.__init__.EasyDeLGradientCheckPointers(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: str, Enum

Enum defining available gradient checkpointing strategies for EasyDeL. Each enum member represents a different checkpointing approach.

CHECKPOINT_DOTS = 'checkpoint_dots'
CHECKPOINT_DOTS_WITH_NO_BATCH_DMIS = 'checkpoint_dots_with_no_batch_dims'
EVERYTHING_SAVEABLE = 'everything_saveable'
NONE = ''
NOTHING_SAVEABLE = 'nothing_saveable'
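The non-empty values above match policy names in jax.checkpoint_policies (checkpoint_dots, checkpoint_dots_with_no_batch_dims, everything_saveable, nothing_saveable). Assuming EasyDeL resolves them that way, a member can be turned into a rematerialization wrapper as sketched below. CheckpointPolicy and maybe_remat are hypothetical names. The enum is mirrored so the sketch runs without EasyDeL; note that the library's member is spelled CHECKPOINT_DOTS_WITH_NO_BATCH_DMIS.

```python
from enum import Enum

import jax
import jax.numpy as jnp


class CheckpointPolicy(str, Enum):
    """Hypothetical mirror of the EasyDeLGradientCheckPointers values."""
    NONE = ""
    EVERYTHING_SAVEABLE = "everything_saveable"
    NOTHING_SAVEABLE = "nothing_saveable"
    CHECKPOINT_DOTS = "checkpoint_dots"
    CHECKPOINT_DOTS_WITH_NO_BATCH_DIMS = "checkpoint_dots_with_no_batch_dims"


def maybe_remat(fn, policy: CheckpointPolicy):
    if policy is CheckpointPolicy.NONE:
        return fn  # no rematerialization
    # Look up the jax.checkpoint_policies entry with the same name
    return jax.checkpoint(fn, policy=getattr(jax.checkpoint_policies, policy.value))


def block(w, x):
    return jnp.tanh(x @ w).sum()


w = jnp.ones((3, 2))
x = jnp.ones((4, 3))
grads = jax.grad(maybe_remat(block, CheckpointPolicy.NOTHING_SAVEABLE))(w, x)
```

The policy only changes which intermediates are stored and which are recomputed in the backward pass. The gradients themselves are the same for every member.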
class easydel.__init__.EasyDeLOptimizers(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: str, Enum

Enum defining available optimizers for EasyDeL. Each enum member represents a different optimization algorithm.

ADAFACTOR = 'adafactor'
ADAMW = 'adamw'
LION = 'lion'
RMSPROP = 'rmsprop'
class easydel.__init__.EasyDeLPlatforms(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: str, Enum

Enum defining available platforms for EasyDeL. Each enum member represents a different kernel usage approach.

JAX = 'jax'
PALLAS = 'pallas'
TRITON = 'triton'
class easydel.__init__.EasyDeLQuantizationMethods(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: str, Enum

Enum defining available quantization strategies for EasyDeL. Each enum member represents a different quantization approach.

A8BIT = '8bit'
NF4 = 'nf4'
NONE = 'None'
exception easydel.__init__.EasyDeLRuntimeError

Bases: Exception

class easydel.__init__.EasyDeLSchedulers(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: str, Enum

Enum defining available schedulers for EasyDeL. Each enum member represents a different learning rate schedule.

COSINE = 'cosine'
LINEAR = 'linear'
NONE = 'None'
class easydel.__init__.EasyDeLState(step: int | jax.Array, graphdef: nn.GraphDef, graphstate: nn.GraphState, graphother: nn.GraphState, tx: optax.GradientTransformation, opt_state: tp.Optional[optax.OptState], apply_fn: tp.Optional[tp.Callable] = None)

Bases: PyTreeNode

EasyDeLState: a snapshot of your EasyDeL model.

The EasyDeLState class is a container that holds all the essential information about an EasyDeL model at a given point in time. It includes the training step, the model’s graph definition and state (graphdef, graphstate, graphother), the optimizer (tx) and its state (opt_state), and an optional apply_fn.

apply_fn: tp.Optional[tp.Callable] = None
apply_gradients(*, grads)

Applies gradients to the model parameters and updates the optimizer state. This function is typically called during training to update the model based on the computed gradients.

Parameters

grads – A dictionary of gradients, where keys correspond to model parameters.

Returns

An updated EasyDeLState object with modified parameters and optimizer state.

Return type

EasyDeLState

classmethod create(*, step: tp.Optional[int] = None, graphdef: tp.Optional[nn.GraphDef] = None, graphstate: tp.Optional[nn.GraphState] = None, graphother: tp.Optional[nn.GraphState] = None, model: tp.Optional[nn.Module] = None, tx: tp.Optional[optax.GradientTransformation] = None, opt_state: tp.Optional[optax.OptState] = None, init_opt_state: bool = False) → EasyDeLState

Create an instance with flexible initialization options.

Parameters
  • step – Optional number of training steps.

  • graphdef – Optional graph definition.

  • graphstate – Optional graph state.

  • graphother – Optional graph state holding the remaining (non-parameter) variables.

  • model – Optional neural network module.

  • tx – Optional gradient transformation.

  • opt_state – Optional optimizer state.

Raises

ValueError – If initialization parameters are inconsistent.

gather_model(partition_rules: Any = None, mesh: Optional[Any] = None) → EasyDeLState

Gathers the model according to the provided partition rules.

Returns

An updated EasyDeLState object with the gathered model.

Return type

EasyDeLState

gather_optimizer_state(partition_rules=None)
gather_state()

Gathers the entire state.

Returns

An updated EasyDeLState object with the gathered state.

Return type

EasyDeLState

graphdef: nn.GraphDef
graphother: nn.GraphState
graphstate: nn.GraphState
init_tx(tx: GradientTransformation, partition_rules: Any = None) → EasyDeLState

Initialize the optimizer state with the given gradient transformation.

Parameters
  • tx (optax.GradientTransformation) – A gradient transformation to initialize the optimizer state.

  • partition_rules (Optional[Any], optional) – Rules for partitioning the optimizer state. Defaults to None.

Returns

An updated EasyDeLState object with the new gradient transformation and sharded optimizer state.

Return type

EasyDeLState

load_optimizer(load_directory: Union[str, PathLike])
load_state(load_directory: Union[str, PathLike], verbose: bool = True)
merge(tree) → Any
merge_to_state(tree) → EasyDeLState
property model: Any
opt_state: tp.Optional[optax.OptState]
replace(**updates)

Returns a new object replacing the specified fields with new values.

save_state(save_directory: Union[str, PathLike], float_dtype: Optional[dtype] = None, verbose: bool = True, mismatch_allowed: bool = True, save_optimizer: bool = True, enable: Optional[bool] = None)
shard_model(partition_rules: Any = None, mesh: Optional[Any] = None) → EasyDeLState

Shards the model according to the provided partition rules.

Parameters
  • partition_rules (Optional[Any]) – The partition rules to be used for sharding. If None, the method will use the partition rules from self.model.config.

  • mesh (Optional[Mesh]) – The mesh to be used for sharding. If None, the method will use the mesh from self.model.

Returns

An updated EasyDeLState object with the sharded model.

Return type

EasyDeLState

shard_optimizer_state(opt_state: Optional[Any] = None, partition_rules: Any = None) → Any

Shards the optimizer state according to the provided partition rules.

Parameters
  • opt_state (Optional[Any]) – The optimizer state to be sharded. If None, the method will use self.opt_state. Raises a ValueError if both opt_state and self.opt_state are None.

  • partition_rules (Optional[Any]) – The partition rules to be used for sharding. If None, the method will use the partition rules from self.model.config.

Returns

The sharded optimizer state.

Return type

Any

Raises

ValueError – If both opt_state and self.opt_state are None.

shard_state(partition_rules: Any = None) → EasyDeLState

Shards the entire state according to the provided partition rules.

Parameters

partition_rules (Optional[Any]) – The partition rules to be used for sharding. If None, the method will use the partition rules from self.model.config.

Returns

An updated EasyDeLState object with the sharded state.

Return type

EasyDeLState

shard_with_shape(shape) → EasyDeLState

Shards the current state using the given shape.

property shardings

Returns the sharding information for the state.

Returns

The sharding information.

Return type

Any

property size: int

Calculates the total size of the optimizer state and model graph state.

Returns

The total size in bytes.

Return type

int
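A total like this is typically obtained by summing size × itemsize over every array leaf. The helper below is a hypothetical sketch of that calculation on an arbitrary pytree, not the property's actual implementation.

```python
import jax
import jax.numpy as jnp


def tree_size_in_bytes(tree) -> int:
    # Bytes per leaf = number of elements * bytes per element
    return sum(leaf.size * leaf.dtype.itemsize for leaf in jax.tree_util.tree_leaves(tree))


params = {"w": jnp.zeros((4, 4), jnp.float32), "b": jnp.zeros((4,), jnp.bfloat16)}
print(tree_size_in_bytes(params))  # 16 * 4 + 4 * 2 = 72
```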

step: int | jax.Array
tx: optax.GradientTransformation
exception easydel.__init__.EasyDeLSyntaxRuntimeError

Bases: Exception

exception easydel.__init__.EasyDeLTimerError

Bases: Exception

class easydel.__init__.ExaoneConfig(vocab_size: int = 102400, hidden_size: int = 2048, intermediate_size: int = 14336, num_layers: int = 32, num_attention_heads: int = 32, num_key_value_heads: int = 8, activation_function='silu', max_position_embeddings=2048, initializer_range=0.02, layer_norm_epsilon=1e-05, use_cache=True, embed_dropout: float = 0.0, pad_token_id: Optional[int] = None, bos_token_id: int = 1, eos_token_id: int = 2, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling: Dict[str, Union[str, float]] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, attention_dropout: float = 0.0, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, **kwargs)

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 102400) – Vocabulary size of the Exaone model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 2048) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 14336) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • head_dim (int, optional, defaults to 128) – Dimensionality of each attention head.

  • num_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 8) – Number of key and value heads for each attention layer in the Transformer encoder.

  • activation_function (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If a string, “silu”, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • layer_norm_epsilon (float, optional, defaults to 1e-5) – The epsilon used by the RMS normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether the model should return the last key/value attention states for reuse as a cache (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional) – The index of the padding token in the vocabulary.

  • bos_token_id (int, optional, defaults to 1) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • use_scan_mlp (bool, optional, defaults to False) – Whether to use the scan implementation for the MLP.

  • scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size to use when scanning the MLP.

  • bits (int, optional) – The number of bits to quantize the model to.

  • attention_bias (bool, optional, defaults to False) – Whether to use bias in the attention layer.
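With the defaults above (32 query heads, 8 key/value heads) the model uses grouped-query attention: each key/value head is shared by 32 / 8 = 4 query heads, so the key/value cache stores 8 heads per layer instead of 32. The sketch below shows the head mapping, assuming the contiguous grouping used by the Hugging Face `repeat_kv` convention (query head `q` reads key/value head `q // group_size`); both helper names are hypothetical.

```python
def kv_head_for_query_head(query_head: int, num_attention_heads: int, num_key_value_heads: int) -> int:
    """Index of the key/value head that a given query head attends with."""
    if num_attention_heads % num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    group_size = num_attention_heads // num_key_value_heads  # query heads per KV head
    return query_head // group_size


def query_groups(num_attention_heads: int, num_key_value_heads: int) -> dict:
    """Map each key/value head to the list of query heads that share it."""
    groups = {}
    for q in range(num_attention_heads):
        kv = kv_head_for_query_head(q, num_attention_heads, num_key_value_heads)
        groups.setdefault(kv, []).append(q)
    return groups


# ExaoneConfig defaults: 32 query heads, 8 key/value heads.
groups = query_groups(32, 8)
print(groups[0])    # [0, 1, 2, 3]
print(len(groups))  # 8
```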

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, attention_dropout: float = 0.0, rope_scaling: Dict[str, Union[str, float]] = None, attention_bias: bool = False, **kwargs)[source]#

Attaches the following custom arguments to the configuration:

Parameters
  • gradient_checkpointing (EasyDeLGradientCheckPointers) – The gradient checkpointing configuration to use.

  • use_scan_mlp (bool) – Whether to use the scan implementation for the MLP.

  • scan_mlp_chunk_size (int) – The chunk size to use when scanning the MLP.

  • bits (tp.Optional[int]) – The number of bits to use for quantization.

  • attention_dropout (float) – The dropout rate for the attention layer.

  • attention_bias (bool) – Whether to use bias in the attention layer.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]]) – The configuration for rope scaling.

Return type

None (the arguments are stored as attributes of the configuration)

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()[source]#
property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
model_type: str = 'exaone'#
static rng_keys()[source]#
class easydel.__init__.ExaoneForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.ExaoneForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.ExaoneModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

property frequencies#

Returns frequency values from the config.

class easydel.__init__.FalconConfig(vocab_size=65024, hidden_size=4544, num_hidden_layers=32, num_attention_heads=71, num_ln_in_parallel_attn=None, layer_norm_epsilon=1e-05, initializer_range=0.02, use_cache=True, hidden_dropout=0.0, attention_dropout=0.0, num_kv_heads=None, alibi=False, new_decoder_architecture=False, multi_query=True, parallel_attn=True, bias=False, max_position_embeddings=2048, rope_theta=10000.0, rope_scaling=None, bos_token_id=11, eos_token_id=11, ffn_hidden_size=None, ff_factor=None, activation='gelu', gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 65024) – Vocabulary size of the Falcon model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4544) – Dimensionality of the encoder layers and the pooler layer.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 71) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_ln_in_parallel_attn (int, optional) – The number of layer norms in the parallel attention layer.

  • layer_norm_epsilon (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • use_cache (bool, optional, defaults to True) – Whether the model should return the last key/value attention states for reuse as a cache (not used by all models). Only relevant if config.is_decoder=True.

  • hidden_dropout (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • num_kv_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer encoder. Will default to num_attention_heads if not set.

  • alibi (bool, optional, defaults to False) – Whether to use ALiBi positional biases in attention.

  • new_decoder_architecture (bool, optional, defaults to False) – Whether to use the new decoder architecture.

  • multi_query (bool, optional, defaults to True) – Whether to use multi-query attention.

  • parallel_attn (bool, optional, defaults to True) – Whether to use parallel attention.

  • bias (bool, optional, defaults to False) – Whether to use bias in the linear layers.

  • max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The rope scaling configuration.

  • bos_token_id (int, optional, defaults to 11) – The index of the beginning of sequence token in the vocabulary.

  • eos_token_id (int, optional, defaults to 11) – The index of the end of sequence token in the vocabulary.

  • ffn_hidden_size (int, optional) – Dimensionality of the hidden layer in the FFN.

  • ff_factor (int, optional) – The scaling factor of the FFN.

  • activation (str, optional, defaults to “gelu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • bits (int, optional) – The number of bits to quantize the model to.
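The interaction between num_kv_heads, multi_query and new_decoder_architecture is not spelled out above. The sketch below mirrors the resolution rules of the Hugging Face transformers Falcon implementation, which EasyDeL's port is assumed (not confirmed) to follow. With the defaults, the 71 query heads share a single key/value head of size 4544 / 71 = 64. The helper names are hypothetical.

```python
def falcon_num_kv_heads(num_attention_heads=71, num_kv_heads=None,
                        multi_query=True, new_decoder_architecture=False):
    """Effective number of key/value heads (assumed to follow the HF Falcon rules)."""
    # num_kv_heads falls back to num_attention_heads when unset.
    resolved = num_attention_heads if num_kv_heads is None else num_kv_heads
    # The new decoder architecture, or disabling multi-query, uses the resolved value;
    # otherwise multi-query attention shares one key/value head across all query heads.
    if new_decoder_architecture or not multi_query:
        return resolved
    return 1


def falcon_head_dim(hidden_size=4544, num_attention_heads=71):
    """Per-head dimensionality; the defaults divide evenly (71 * 64 = 4544)."""
    return hidden_size // num_attention_heads


print(falcon_head_dim())                       # 64
print(falcon_num_kv_heads())                   # 1
print(falcon_num_kv_heads(multi_query=False))  # 71
print(falcon_num_kv_heads(num_kv_heads=8, new_decoder_architecture=True))  # 8
```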

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#
attribute_map: Dict[str, str] = {'num_attention_heads': 'num_attention_heads', 'num_hidden_layers': 'num_hidden_layers'}#
static get_mesh_names()[source]#
get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
model_type: str = 'falcon'#
property rotary#
class easydel.__init__.FalconForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.FalconModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.FlexibleAttentionModule(*args: Any, **kwargs: Any)[source]#

Bases: Module

Manages different attention mechanisms for efficient computation in EasyDeL models.

This class serves as a central hub for handling various attention mechanisms, including optimized implementations like FlashAttention, SplashAttention, RingAttention, and more traditional approaches like vanilla (dot-product) attention. It provides a unified interface to select and execute the appropriate attention mechanism based on the model’s configuration and hardware platform.

Key Features:

  • Attention Mechanism Selection: Supports a wide range of attention mechanisms, allowing users to choose the most suitable option based on performance and hardware constraints.

  • Sharding and Partitioning: Integrates with JAX’s sharding capabilities, enabling efficient distribution of computations and data across multiple devices.

  • Block-wise Computation: Implements block-wise attention computations for optimized memory usage and speed, particularly beneficial for large models.

  • Performance Optimization: Includes support for highly optimized implementations like FlashAttention, SplashAttention, and RingAttention for TPU and GPU acceleration.

  • Flexibility and Customization: Offers fine-grained control over attention parameters, sharding specifications, and block sizes, providing flexibility for different use cases.

  • Testing and Evaluation: Includes a run_attention_benchmarks method to systematically evaluate different attention mechanisms and help users identify the best-performing option.

The FlexibleAttentionModule class is a central component of EasyDeL, responsible for managing and optimizing attention computations. It provides a simple way to select and run different attention mechanisms, builds on JAX’s sharding capabilities, and improves performance through specialized implementations such as FlashAttention and SplashAttention. Its block-wise computation and customization options make it adaptable to a variety of model architectures and hardware configurations.
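Block-wise attention works because softmax attention can be accumulated one key/value block at a time, keeping only a running maximum, a running sum of exponentials, and a running weighted sum of values. This is the online-softmax trick used by FlashAttention-style kernels. The sketch below shows it for a single query in plain Python and checks it against the vanilla (dot-product) result. `softmax_attention` and `blockwise_attention` are hypothetical helpers, not the module's API; real kernels operate on batched arrays on device.

```python
import math


def softmax_attention(q, keys, values):
    """Vanilla attention for one query: softmax(q.k / sqrt(d)) weighted sum of values."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dv = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dv)]


def blockwise_attention(q, keys, values, block_size):
    """Same result, visiting keys/values in blocks with an online softmax."""
    d, dv = len(q), len(values[0])
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * dv
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated so far to the new maximum (exp(-inf) == 0.0).
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            for j in range(dv):
                acc[j] += w * v[j]
        running_max = new_max
    return [a / running_sum for a in acc]


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5], [0.3, 0.3]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
full = softmax_attention(q, keys, values)
blocked = blockwise_attention(q, keys, values, block_size=2)
print(all(abs(a - b) < 1e-12 for a, b in zip(full, blocked)))  # True
```

Only one block of scores exists at a time in the block-wise version, which is where the memory savings come from.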

forward(query_states: Union[Array, ndarray, bool, number], key_states: Union[Array, ndarray, bool, number], value_states: Union[Array, ndarray, bool, number], bias: Optional[Union[Array, ndarray, bool, number]] = None, init_bias: Optional[Callable[[], Union[Array, ndarray, bool, number]]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None, segment_ids: Optional[Union[Array, ndarray, bool, number]] = None, causal: bool = True, dropout_rng: Optional[PRNGKey] = None) AttentionOutput[source]#
class easydel.__init__.GPT2Config(vocab_size=50257, n_positions=1024, n_embd=768, n_layer=12, n_head=12, n_inner=None, activation_function='gelu_new', resid_pdrop=0.1, embd_pdrop=0.1, attn_pdrop=0.1, layer_norm_epsilon=1e-05, initializer_range=0.02, summary_type='cls_index', summary_use_proj=True, summary_activation=None, summary_proj_to_labels=True, summary_first_dropout=0.1, scale_attn_weights=True, use_cache=True, bos_token_id=50256, eos_token_id=50256, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, tie_word_embeddings: bool = False, bits: Optional[int] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 50257) – Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • n_positions (int, optional, defaults to 1024) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • n_embd (int, optional, defaults to 768) – Dimensionality of the encoder layers and the pooler layer.

  • n_layer (int, optional, defaults to 12) – Number of hidden layers in the Transformer encoder.

  • n_head (int, optional, defaults to 12) – Number of attention heads for each attention layer in the Transformer encoder.

  • n_inner (int, optional) – Dimensionality of the inner feed-forward layers.

  • activation_function (str, optional, defaults to “gelu_new”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • resid_pdrop (float, optional, defaults to 0.1) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

  • embd_pdrop (float, optional, defaults to 0.1) – The dropout ratio for the embeddings.

  • attn_pdrop (float, optional, defaults to 0.1) – The dropout ratio for the attention probabilities.

  • layer_norm_epsilon (float, optional, defaults to 1e-5) – The epsilon to use in the layer normalization layers.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • summary_type (str, optional, defaults to “cls_index”) – The summary type to use. Possible values are “cls_index” (equivalent to the output of the last token of the first sentence in a sequence) and “last” (equivalent to the output of the last token in the sequence).

  • summary_use_proj (bool, optional, defaults to True) – Whether to use a projection after the vector extraction.

  • summary_activation (str, optional) – The activation to use for the summary.

  • summary_proj_to_labels (bool, optional, defaults to True) – Whether to project the summary to the labels.

  • summary_first_dropout (float, optional, defaults to 0.1) – The dropout ratio to be used after the projection and activation.

  • scale_attn_weights (bool, optional, defaults to True) – Whether to scale attention weights by dividing by sqrt(head_dim), where head_dim = n_embd / n_head.

  • use_cache (bool, optional, defaults to True) – Whether the model should return the last key/value attention states for reuse as a cache (not used by all models). Only relevant if config.is_decoder=True.

  • bos_token_id (int, optional, defaults to 50256) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 50256) – The id of the end-of-sequence token.

  • scale_attn_by_inverse_layer_idx (bool, optional, defaults to False) – Whether to additionally scale attention weights by 1 / (layer_idx + 1).

  • reorder_and_upcast_attn (bool, optional, defaults to False) – Whether to reorder and upcast attention.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • bits (int, optional) – The number of bits to quantize the model to.
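The two attention-scaling flags combine multiplicatively. The sketch below computes the factor applied to the raw query-key scores, following the Hugging Face GPT-2 reference behaviour (division by sqrt(head_dim), then by layer_idx + 1 with a zero-based layer index). EasyDeL's GPT-2 is assumed to match. `gpt2_attention_scale` is a hypothetical helper. With the defaults, head_dim = 768 / 12 = 64, so the scale is 1 / 8.

```python
import math


def gpt2_attention_scale(n_embd=768, n_head=12, layer_idx=0,
                         scale_attn_weights=True,
                         scale_attn_by_inverse_layer_idx=False):
    """Factor multiplied into the raw q.k scores before softmax (layer_idx is 0-based)."""
    head_dim = n_embd // n_head
    scale = 1.0
    if scale_attn_weights:
        scale /= math.sqrt(head_dim)
    if scale_attn_by_inverse_layer_idx:
        scale /= float(layer_idx + 1)
    return scale


print(gpt2_attention_scale())  # 0.125  (1 / sqrt(64))
print(gpt2_attention_scale(layer_idx=3, scale_attn_by_inverse_layer_idx=True))  # 0.03125
```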

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#
attribute_map: Dict[str, str] = {'hidden_size': 'n_embd', 'max_position_embeddings': 'n_positions', 'num_attention_heads': 'n_head', 'num_hidden_layers': 'n_layer'}#
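attribute_map lets code written against the generic names (hidden_size, num_hidden_layers, and so on) read and write GPT-2's own field names (n_embd, n_layer, and so on). The following is a minimal sketch of that aliasing, assuming EasyDeLBaseConfig resolves attribute_map the same way Hugging Face's PretrainedConfig does. `AliasedConfig` is a hypothetical class.

```python
class AliasedConfig:
    """Hypothetical config whose generic names alias GPT-2-style field names."""

    attribute_map = {
        "hidden_size": "n_embd",
        "max_position_embeddings": "n_positions",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __setattr__(self, name, value):
        # Writes through a generic name are redirected to the model-specific field.
        super().__setattr__(type(self).attribute_map.get(name, name), value)

    def __getattr__(self, name):
        # Only reached when normal lookup fails, e.g. for a generic alias.
        mapped = type(self).attribute_map.get(name)
        if mapped is None:
            raise AttributeError(name)
        return getattr(self, mapped)


cfg = AliasedConfig(n_embd=768, n_positions=1024, n_head=12, n_layer=12)
print(cfg.hidden_size, cfg.num_hidden_layers)  # 768 12
cfg.hidden_size = 1024
print(cfg.n_embd)  # 1024
```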
get_partition_rules(fully_sharded_data_parallel: bool = True)[source]#

Get the partition rules for the model.

Parameters

fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]
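Each rule pairs a regular expression with a partition spec. The sketch below shows how such rules are commonly applied, assuming a first-match-wins `re.search` over "/"-joined parameter paths (the convention of EasyLM-style partition rules). The patterns, the plain tuples standing in for PartitionSpec objects, and `match_partition_rule` are all hypothetical and do not reproduce GPT2Config's actual rules. Only the mesh-axis names ("fsdp", "tp") are taken from EasyDeL's naming.

```python
import re

# Hypothetical rules: (regex over "/"-joined parameter paths, partition spec).
# Plain tuples of mesh-axis names stand in for jax.sharding.PartitionSpec objects.
RULES = (
    (r"wte/embedding", ("tp", "fsdp")),
    (r"attn/c_attn/kernel", ("fsdp", "tp")),
    (r"attn/c_proj/kernel", ("tp", "fsdp")),
    (r".*", None),  # catch-all: replicate everything else
)


def match_partition_rule(param_path, rules=RULES):
    """Return the spec of the first rule whose pattern is found in param_path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path) is not None:
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


print(match_partition_rule("transformer/h/0/attn/c_attn/kernel"))  # ('fsdp', 'tp')
print(match_partition_rule("transformer/ln_f/scale"))              # None
```

Because the first match wins, specific patterns must come before the catch-all.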

keys_to_ignore_at_inference = ['past_key_values']#
model_type: str = 'gpt2'#
class easydel.__init__.GPT2LMHeadModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.GPT2Model(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.GPTJConfig(vocab_size: int = 50400, n_positions: int = 2048, n_embd: int = 4096, n_layer: int = 28, n_head: int = 16, rotary_dim: int = 64, n_inner: int = None, activation_function: str = 'gelu_new', resid_pdrop: float = 0.0, embd_pdrop: float = 0.0, attn_pdrop: float = 0.0, layer_norm_epsilon: float = 1e-05, initializer_range: int = 0.02, use_cache: int = True, bos_token_id: int = 50256, eos_token_id: int = 50256, tie_word_embeddings: bool = False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 50400) – Vocabulary size of the GPT-J model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • n_positions (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • n_embd (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.

  • n_layer (int, optional, defaults to 28) – Number of hidden layers in the Transformer encoder.

  • n_head (int, optional, defaults to 16) – Number of attention heads for each attention layer in the Transformer encoder.

  • rotary_dim (int, optional, defaults to 64) – Number of dimensions of each attention head to which rotary position embeddings are applied.

  • n_inner (int, optional) – Dimensionality of the inner feed-forward layers.

  • activation_function (str, optional, defaults to “gelu_new”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • resid_pdrop (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

  • embd_pdrop (float, optional, defaults to 0.0) – The dropout ratio for the embeddings.

  • attn_pdrop (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • layer_norm_epsilon (float, optional, defaults to 1e-5) – The epsilon to use in the layer normalization layers.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • use_cache (bool, optional, defaults to True) – Whether the model should return the last key/value attention states for reuse as a cache (not used by all models). Only relevant if config.is_decoder=True.

  • bos_token_id (int, optional, defaults to 50256) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 50256) – The id of the end-of-sequence token.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • bits (int, optional) – The number of bits to quantize the model to.
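With the defaults, each head has 4096 / 16 = 256 dimensions, and rotary_dim = 64 of them are rotated by position while the remaining 192 pass through unchanged. The sketch below applies partial rotary embeddings to a single head vector. It assumes GPT-J-style interleaved pairing ((0, 1), (2, 3), ...) and base 10000; EasyDeL's exact memory layout may differ, and `apply_partial_rotary` is a hypothetical helper.

```python
import math


def apply_partial_rotary(x, position, rotary_dim, base=10000.0):
    """Rotate the first `rotary_dim` entries of one head vector `x`.

    Entries are paired as (0, 1), (2, 3), ... (GPT-J-style interleaving);
    entries from rotary_dim onward are returned unchanged.
    """
    if rotary_dim % 2 != 0 or rotary_dim > len(x):
        raise ValueError("rotary_dim must be even and at most len(x)")
    out = list(x)
    for i in range(0, rotary_dim, 2):
        inv_freq = base ** (-i / rotary_dim)  # 1 / base^(i / rotary_dim)
        angle = position * inv_freq
        cos, sin = math.cos(angle), math.sin(angle)
        x0, x1 = x[i], x[i + 1]
        out[i] = x0 * cos - x1 * sin
        out[i + 1] = x0 * sin + x1 * cos
    return out


head = [float(v) for v in range(8)]
rotated = apply_partial_rotary(head, position=3, rotary_dim=4)
print(rotated[4:] == head[4:])  # True: only the first 4 entries are rotated
```

Each pair is a 2D rotation, so the norm of every rotated pair is preserved.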

attach_custom_arguments(vocab_size: int = 50400, n_positions: int = 2048, n_embd: int = 4096, n_layer: int = 28, n_head: int = 16, rotary_dim: int = 64, n_inner: int = None, activation_function: str = 'gelu_new', resid_pdrop: float = 0.0, embd_pdrop: float = 0.0, attn_pdrop: float = 0.0, layer_norm_epsilon: float = 1e-05, initializer_range: int = 0.02, use_cache: int = True, bos_token_id: int = 50256, eos_token_id: int = 50256, tie_word_embeddings: bool = False, bits: Optional[int] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#
attribute_map: Dict[str, str] = {'hidden_size': 'n_embd', 'max_position_embeddings': 'n_positions', 'num_attention_heads': 'n_head', 'num_hidden_layers': 'n_layer'}#
static get_mesh_names()[source]#
get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'gptj'#
class easydel.__init__.GPTJForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.GPTJModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

property frequencies#

Returns frequency values from the config.

class easydel.__init__.GPTNeoXConfig(vocab_size=50432, hidden_size=6144, num_hidden_layers=44, num_attention_heads=64, intermediate_size=24576, hidden_act='gelu', rotary_pct=0.25, rotary_emb_base=10000, attention_dropout=0.0, hidden_dropout=0.0, classifier_dropout=0.1, max_position_embeddings=2048, initializer_range=0.02, layer_norm_eps=1e-05, use_cache=True, bos_token_id=0, eos_token_id=2, tie_word_embeddings=False, use_parallel_residual=True, rope_scaling=None, attention_bias=True, gradient_checkpointing=EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 50432) – Vocabulary size of the GPT NeoX model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 6144) – Dimensionality of the encoder layers and the pooler layer.

  • num_hidden_layers (int, optional, defaults to 44) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 64) – Number of attention heads for each attention layer in the Transformer encoder.

  • intermediate_size (int, optional, defaults to 24576) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • hidden_act (str or function, optional, defaults to “gelu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • rotary_pct (float, optional, defaults to 0.25) – The fraction of each attention head’s dimensions to which rotary embeddings are applied.

  • rotary_emb_base (int, optional, defaults to 10000) – The base for the rotary position embedding.

  • classifier_dropout (float, optional, defaults to 0.1) – The dropout ratio for the classifier layer.

  • max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • layer_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether the model should return the last key/value attention states for reuse as a cache (not used by all models). Only relevant if config.is_decoder=True.

  • bos_token_id (int, optional, defaults to 0) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • use_parallel_residual (bool, optional, defaults to True) – Whether each Transformer layer computes attention and the MLP in parallel from the same input (x + attn(ln1(x)) + mlp(ln2(x))) rather than sequentially.
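Two of these parameters are easier to see with numbers. Assuming the GPT-NeoX reference convention that the rotary width is int(head_size * rotary_pct), the defaults give 6144 / 64 = 96 dimensions per head, of which 24 are rotary. The sketch below also contrasts the parallel and sequential residual forms using toy scalar "layers". All function names are hypothetical, and the scalar layers stand in for real attention, MLP, and layer norms.

```python
def rotary_ndims(hidden_size: int, num_attention_heads: int, rotary_pct: float) -> int:
    """Number of per-head dimensions that receive rotary embeddings."""
    head_size = hidden_size // num_attention_heads
    return int(head_size * rotary_pct)


def sequential_residual(x, attn, mlp, ln1, ln2):
    """use_parallel_residual=False: the MLP sees the attention output."""
    h = x + attn(ln1(x))
    return h + mlp(ln2(h))


def parallel_residual(x, attn, mlp, ln1, ln2):
    """use_parallel_residual=True: attention and MLP both read the layer input x."""
    return x + attn(ln1(x)) + mlp(ln2(x))


# Toy scalar layers that make the difference in data flow visible.
def identity(v):
    return v


def toy_attn(v):
    return 2 * v


def toy_mlp(v):
    return v + 1


print(rotary_ndims(6144, 64, 0.25))  # 24
print(sequential_residual(1.0, toy_attn, toy_mlp, identity, identity))  # 7.0
print(parallel_residual(1.0, toy_attn, toy_mlp, identity, identity))    # 5.0
```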

attach_custom_arguments(**kwargs)[source]#
static get_mesh_names()[source]#
get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'gpt_neox'#
class easydel.__init__.GPTNeoXForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.GPTNeoXModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

property frequencies#

Returns frequency values from the config.

class easydel.__init__.GRPOConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: ~typing.Optional[str] = None, clip_grad: ~typing.Optional[float] = None, custom_scheduler: ~typing.Optional[~typing.Callable[[int], ~typing.Any]] = None, dataloader_num_workers: ~typing.Optional[int] = 0, dataloader_pin_memory: ~typing.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: ~typing.Optional[int] = None, evaluation_steps: ~typing.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: ~typing.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: ~typing.Optional[~typing.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: ~typing.Optional[dict] = None, learning_rate: float = 5e-05, learning_rate_end: ~typing.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: ~typing.Optional[~easydel.infra.loss_utils.LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: ~typing.Optional[int] = None, max_sequence_length: ~typing.Optional[int] = 4096, max_training_steps: ~typing.Optional[int] = None, model_name: str = 'BaseTrainer', model_parameters: ~typing.Optional[dict] = None, metrics_to_show_in_rich_pbar: ~typing.Optional[~typing.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: ~typing.Literal['adafactor', 'lion', 'adamw', 'rmsprop'] = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: ~typing.Any = None, process_zero_is_admin: bool = True, progress_bar_type: ~typing.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', 
save_optimizer_state: bool = True, save_steps: ~typing.Optional[int] = None, save_total_limit: ~typing.Optional[int] = None, scheduler: ~typing.Literal['linear', 'cosine', 'none'] = EasyDeLSchedulers.NONE, sparsify_module: bool = False, sparse_module_type: ~typing.Literal['bcoo', 'bcsr', 'coo', 'csr'] = 'bcoo', state_apply_fn_kwarguments_to_model: ~typing.Optional[dict] = None, step_partition_spec: ~jax._src.partition_spec.PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: ~typing.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: ~typing.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: ~typing.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: ~typing.Optional[~numpy.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: ~typing.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0)[source]#

Bases: TrainingArguments

Configuration class for the GRPOTrainer.

beta: float = Field(name=None,type=None,default=0.04,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The beta parameter for GRPO.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
dataset_num_proc: Optional[int] = Field(name=None,type=None,default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The number of processes to use for dataset processing.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
learning_rate: float = Field(name=None,type=None,default=1e-06,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The learning rate.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
max_completion_length: int = Field(name=None,type=None,default=256,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The maximum length of the completion.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
max_prompt_length: int = Field(name=None,type=None,default=512,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The maximum length of the prompt.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
model_name: str = Field(name=None,type=None,default='GRPOTrainer',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The name of the model.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
ref_model_mixup_alpha: float = Field(name=None,type=None,default=0.9,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The alpha parameter for mixing the reference model with the policy model.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
ref_model_sync_steps: int = Field(name=None,type=None,default=64,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The number of steps between syncing the reference model.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
remove_unused_columns: Optional[bool] = Field(name=None,type=None,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether to remove unused columns from the dataset.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
replace(**kwargs)#
skip_apply_chat_template: bool = Field(name=None,type=None,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'whenever to skip extracting prompt from dataset.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
sync_ref_model: bool = False#

Whether to periodically sync the reference model with the policy model.

tools: Optional[List[Union[dict, Callable]]] = None#

Additional tools for training.
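When sync_ref_model is enabled, ref_model_sync_steps and ref_model_mixup_alpha control how the reference model follows the policy. The sketch below assumes the TR-DPO-style soft update used by Hugging Face TRL, ref = alpha * policy + (1 - alpha) * ref, applied every ref_model_sync_steps steps; EasyDeL's actual update may differ, and the helper maybe_sync_reference is hypothetical.

```python
import jax
import jax.numpy as jnp


def maybe_sync_reference(policy_params, ref_params, step, sync_steps=64, alpha=0.9):
    """Hypothetical helper: soft-update the reference params every `sync_steps` steps.

    Assumes `step` counts completed steps starting at 1 and the TRL / TR-DPO
    convention ref <- alpha * policy + (1 - alpha) * ref. Illustration only.
    """
    if step % sync_steps != 0:
        return ref_params  # not a sync step: keep the reference unchanged
    return jax.tree_util.tree_map(
        lambda p, r: alpha * p + (1.0 - alpha) * r, policy_params, ref_params
    )


policy = {"w": jnp.ones((2,)), "b": jnp.array(2.0)}
ref = {"w": jnp.zeros((2,)), "b": jnp.array(0.0)}
synced = maybe_sync_reference(policy, ref, step=64)     # sync step: params are mixed
unchanged = maybe_sync_reference(policy, ref, step=10)  # returned as-is
```

A larger alpha makes the reference track the policy more closely; alpha = 1.0 would amount to a hard copy.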
class easydel.__init__.GRPOTrainer(arguments: GRPOConfig, vinference: vInference, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]], reward_funcs: Union[EasyDeLBaseModule, EasyDeLState, Callable[[list, list], list[float]], list[Union[easydel.infra.base_module.EasyDeLBaseModule, easydel.infra.base_state.EasyDeLState, Callable[[list, list], list[float]]]]], train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, processing_class: Optional[Any] = None, reward_processing_classes: Optional[Any] = None, data_tokenize_fn: Optional[Callable] = None)[source]#

Bases: Trainer
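The signature types each plain-callable reward as Callable[[list, list], list[float]]. Below is a minimal sketch of such a callable. It assumes the two lists are the prompts and the generated completions (the signature alone does not say which) and that one float per completion is expected. The function name and its length heuristic are hypothetical.

```python
from typing import List


def length_penalty_reward(prompts: list, completions: list) -> List[float]:
    """Hypothetical reward: prefer completions close to a target word count.

    Assumes `prompts` and `completions` are parallel lists of strings and that
    one float per completion is returned to the trainer.
    """
    target_words = 20
    rewards = []
    for completion in completions:
        n_words = len(str(completion).split())
        # 1.0 at the target length, decreasing linearly, floored at 0.0
        rewards.append(max(0.0, 1.0 - abs(n_words - target_words) / target_words))
    return rewards


scores = length_penalty_reward(["Say hi"] * 2, ["hi " * 20, "hi"])
```

Because reward_funcs also accepts a list, a callable like this could be combined with a reward model, e.g. reward_funcs=[reward_model, length_penalty_reward].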

arguments: GRPOConfig#
configure_functions() TrainerConfigureFunctionOutput[source]#

Configures and JIT-compiles the training and evaluation step functions.

This method sets up the necessary functions for training and evaluation, including:
  • Initialization of the model state.

  • Sharding of the model parameters and optimizer state.

  • JIT-compilation of the training and evaluation step functions.

Returns

An object containing the configured functions and other relevant information.

Return type

TrainerConfigureFunctionOutput

on_step_end(state: EasyDeLState, metrics: Any, step: int) Tuple[EasyDeLState, Any][source]#

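Hook called at the end of each training step. It receives the current state, metrics, and step number, and returns a (possibly updated) state and metrics.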

class easydel.__init__.Gemma2Config(vocab_size=256000, hidden_size=3072, intermediate_size=24576, num_hidden_layers=28, num_attention_heads=16, num_key_value_heads=16, head_dim=256, hidden_activation='gelu_pytorch_tanh', max_position_embeddings=8192, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, pad_token_id=0, eos_token_id=1, bos_token_id=2, tie_word_embeddings=True, rope_theta=10000.0, attention_bias=False, attention_dropout=0.0, final_logit_softcapping=30.0, query_pre_attn_scalar=224, sliding_window=4096, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, scan_layers: bool = False, attn_logit_softcapping: Optional[bool] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 256000) – Vocabulary size of the Gemma2 model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 3072) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 24576) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 28) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 16) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 16) – Number of key and value heads for each attention layer in the Transformer encoder.

  • head_dim (int, optional, defaults to 256) – Dimensionality of the attention head.

  • hidden_activation (str or function, optional, defaults to “gelu_pytorch_tanh”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 8192) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional, defaults to 0) – The index of the padding token in the vocabulary.

  • eos_token_id (int, optional, defaults to 1) – The index of the end of sequence token in the vocabulary.

  • bos_token_id (int, optional, defaults to 2) – The index of the beginning of sequence token in the vocabulary.

  • tie_word_embeddings (bool, optional, defaults to True) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • attention_bias (bool, optional, defaults to False) – Whether to use attention bias.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • final_logit_softcapping (float, optional, defaults to 30.0) – The soft capping value for the final logits.

  • query_pre_attn_scalar (int, optional, defaults to 224) – Scaling factor applied to the queries before attention; queries are multiplied by query_pre_attn_scalar ** -0.5.

  • sliding_window (int, optional, defaults to 4096) – The sliding window size.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing (rematerialization) policy.

  • bits (int, optional) – The number of bits to quantize the model to.

  • scan_layers (bool, optional, defaults to False) – Whether to use the scan implementation of the layers.
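final_logit_softcapping bounds the logits smoothly, and query_pre_attn_scalar sets how the queries are scaled. The sketch below assumes the convention of the reference Gemma 2 implementation: logits = cap * tanh(logits / cap), and queries multiplied by query_pre_attn_scalar ** -0.5. The helper names are hypothetical.

```python
import jax.numpy as jnp


def soft_cap(logits: jnp.ndarray, cap: float) -> jnp.ndarray:
    """Smoothly squash values into (-cap, cap); close to identity when |x| << cap."""
    return cap * jnp.tanh(logits / cap)


def scaled_scores(q: jnp.ndarray, k: jnp.ndarray, query_pre_attn_scalar: float = 224.0):
    """Single-head attention scores with queries scaled by query_pre_attn_scalar ** -0.5.

    q: [seq_q, head_dim], k: [seq_k, head_dim].
    """
    q = q * query_pre_attn_scalar ** -0.5
    return q @ k.T


logits = jnp.array([-100.0, 0.0, 5.0, 100.0])
capped = soft_cap(logits, cap=30.0)  # magnitudes stay strictly below 30
scores = scaled_scores(jnp.ones((1, 4)), jnp.ones((2, 4)))
```

Unlike hard clipping, tanh soft-capping keeps gradients nonzero for large logits.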

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

The attach_custom_arguments function adds the following arguments to the Transformer class:

Parameters
  • self – Refer to the current object

  • gradient_checkpointing – EasyDeLGradientCheckPointers: The gradient checkpointing (rematerialization) policy, which trades recomputation for lower activation memory.

  • bits – tp.Optional[int]: Determine the number of bits used in the quantization

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]
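get_partition_rules returns a tuple of (regex, PartitionSpec) pairs. The sketch below shows one way such rules can be resolved against a parameter path. It assumes first-match-wins with re.search over a "/"-joined path. The rule set, the mesh axis names ("fsdp", "sp", "tp"), and the spec_for helper are hypothetical and are not Gemma2's actual rules.

```python
import re

from jax.sharding import PartitionSpec as PS

# Hypothetical rules in the documented (regex, PartitionSpec) shape.
EXAMPLE_RULES = (
    ("embed_tokens/embedding", PS("tp", ("fsdp", "sp"))),
    ("self_attn/(q_proj|k_proj|v_proj)/kernel", PS(("fsdp", "sp"), "tp")),
    ("self_attn/o_proj/kernel", PS("tp", ("fsdp", "sp"))),
    (".*", PS(None)),  # fallback: replicate everything else
)


def spec_for(param_path: str, rules=EXAMPLE_RULES) -> PS:
    """Return the PartitionSpec of the first rule whose regex matches the path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


kv_spec = spec_for("model/layers/0/self_attn/k_proj/kernel")
norm_spec = spec_for("model/norm/scale")
```

Ordering matters under first-match-wins, so the catch-all rule goes last.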

static get_weight_decay_exclusions()[source]#
property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
model_type: str = 'gemma2'#
static rng_keys()[source]#
class easydel.__init__.Gemma2ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.Gemma2ForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.Gemma2Model(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.Gemma3Config(text_config: Optional[Gemma3TextConfig] = None, vision_config: Optional[SiglipVisionConfig] = None, mm_tokens_per_image: int = 256, boi_token_index: int = 255999, eoi_token_index: int = 256000, image_token_index: int = 262144, initializer_range: float = 0.02, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Parameters
  • text_config (Union[Gemma3TextConfig, dict], optional) – The config object of the text backbone.

  • vision_config (Union[AutoConfig, dict], optional) – Custom vision config or dict.

  • mm_tokens_per_image (int, optional, defaults to 256) – The number of tokens per image embedding.

  • boi_token_index (int, optional, defaults to 255999) – The begin-of-image token index to wrap the image prompt.

  • eoi_token_index (int, optional, defaults to 256000) – The end-of-image token index to wrap the image prompt.

  • image_token_index (int, optional, defaults to 262144) – The image token index to encode the image prompt.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

Example:

```python
>>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig

>>> # Initializing a Siglip-like vision config
>>> vision_config = SiglipVisionConfig()
>>> # Initializing a Gemma3 Text config
>>> text_config = Gemma3TextConfig()
>>> # Initializing a Gemma3 gemma-3-4b style configuration (keyword arguments avoid mixing up the order)
>>> configuration = Gemma3Config(text_config=text_config, vision_config=vision_config)
>>> # Initializing a model from the gemma-3-4b style configuration
>>> model = Gemma3ForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'gemma3'#
sub_configs: Dict[str, 'PretrainedConfig'] = {'text_config': <class 'easydel.__init__.modules.gemma3.gemma3_configuration.Gemma3TextConfig'>, 'vision_config': <class 'easydel.__init__.modules.siglip.configuration_siglip.SiglipVisionConfig'>}#
class easydel.__init__.Gemma3ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.Gemma3ForConditionalGeneration(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

get_image_features(pixel_values: Union[Array, ndarray, bool, number]) Union[Array, ndarray, bool, number][source]#
loss_type = 'ForCausalLM'#
prepare_inputs_for_generation(input_ids: Union[Array, ndarray, bool, number], max_length: int, pixel_values: Optional[Union[Array, ndarray, bool, number]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None, token_type_ids: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

The prepare_inputs_for_generation function is used to prepare the inputs for a generation task.

Parameters
  • self – Access variables that belong to the class.

  • input_ids – The input token ids.

  • max_length – The maximum length of the sequence to be generated.

  • pixel_values – tp.Optional[chex.Array]: Pixel values of the input images.

  • attention_mask – tp.Optional[chex.Array]: Mask applied to the attention weights.

  • token_type_ids – tp.Optional[chex.Array]: Token type ids.

Returns

A dictionary of the past_key_values, attention_mask and position ids

update_inputs_for_generation(model_outputs, model_kwargs)[source]#
class easydel.__init__.Gemma3ForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.Gemma3MultiModalProjector(*args: Any, **kwargs: Any)[source]#

Bases: Module

class easydel.__init__.Gemma3TextConfig(vocab_size=262208, hidden_size=2304, intermediate_size=9216, num_hidden_layers=26, num_attention_heads=8, num_key_value_heads=4, head_dim=256, hidden_activation='gelu_pytorch_tanh', max_position_embeddings=131072, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, pad_token_id=0, eos_token_id=1, bos_token_id=2, tie_word_embeddings=True, rope_theta=1000000.0, attention_bias=False, attention_dropout=0.0, query_pre_attn_scalar=256, sliding_window=4096, final_logit_softcapping=None, attn_logit_softcapping=None, cache_implementation='hybrid', rope_scaling=None, rope_local_base_freq=10000.0, sliding_window_pattern=6, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, scan_layers: bool = False, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 262208) – Vocabulary size of the Gemma3Text model. Defines the number of different tokens that can be represented by the inputs_ids passed when calling [Gemma3TextModel].

  • hidden_size (int, optional, defaults to 2304) – Dimension of the hidden representations.

  • intermediate_size (int, optional, defaults to 9216) – Dimension of the MLP representations.

  • num_hidden_layers (int, optional, defaults to 26) – Number of hidden layers in the Transformer decoder.

  • num_attention_heads (int, optional, defaults to 8) – Number of attention heads for each attention layer in the Transformer decoder.

  • num_key_value_heads (int, optional, defaults to 4) – The number of key/value heads used to implement Grouped Query Attention (GQA). If num_key_value_heads=num_attention_heads, the model uses Multi Head Attention (MHA); if num_key_value_heads=1, it uses Multi Query Attention (MQA); otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by mean-pooling all the original heads within that group. For more details, check out [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it defaults to num_attention_heads.

  • head_dim (int, optional, defaults to 256) – The attention head dimension.

  • hidden_activation (str or function, optional, defaults to “gelu_pytorch_tanh”) – The non-linear activation function (function or string) in the decoder. Will default to “gelu_pytorch_tanh” if not specified. “gelu_pytorch_tanh” uses an approximation of the “gelu” activation function.

  • max_position_embeddings (int, optional, defaults to 131072) – The maximum sequence length that this model might ever be used with.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-06) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional, defaults to 0) – Padding token id.

  • eos_token_id (int, optional, defaults to 1) – End of stream token id.

  • bos_token_id (int, optional, defaults to 2) – Beginning of stream token id.

  • tie_word_embeddings (bool, optional, defaults to True) – Whether to tie weight embeddings

  • rope_theta (float, optional, defaults to 1000000.0) – The base period of the RoPE embeddings.

  • attention_bias (bool, optional, defaults to False) – Whether to use a bias in the query, key, value and output projection layers during self-attention.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • query_pre_attn_scalar (float, optional, defaults to 256) – Scaling factor used on the attention scores

  • sliding_window (int, optional, defaults to 4096) – In Gemma3Text, most layers use sliding-window (local) attention, with one global-attention layer every sliding_window_pattern layers. This is the size of the sliding window.

  • final_logit_softcapping (float, optional) – Scaling factor when applying tanh softcapping on the logits.

  • attn_logit_softcapping (float, optional) – Scaling factor when applying tanh softcapping on the attention scores.

  • cache_implementation (str, optional, defaults to “hybrid”) – The cache type to be used with generate.

  • rope_scaling (Dict, optional) –

    Dictionary containing the scaling configuration for the RoPE embeddings used in global attention. NOTE: if you apply a new RoPE type and expect the model to work on a longer max_position_embeddings, we recommend updating this value accordingly. Expected contents:

    rope_type (str):

    The sub-variant of RoPE to use. Can be one of [‘default’, ‘linear’, ‘dynamic’, ‘yarn’, ‘longrope’, ‘llama3’], with ‘default’ being the original RoPE implementation.

    factor (float, optional):

    Used with all rope types except ‘default’. The scaling factor to apply to the RoPE embeddings. In most scaling types, a factor of x will enable the model to handle sequences of length x * original maximum pre-trained length.

    original_max_position_embeddings (int, optional):

    Used with ‘dynamic’, ‘longrope’ and ‘llama3’. The original max position embeddings used during pretraining.

    attention_factor (float, optional):

    Used with ‘yarn’ and ‘longrope’. The scaling factor to be applied on the attention computation. If unspecified, it defaults to the value recommended by the implementation, using the factor field to infer the suggested value.

    beta_fast (float, optional):

    Only used with ‘yarn’. Parameter to set the boundary for extrapolation (only) in the linear ramp function. If unspecified, it defaults to 32.

    beta_slow (float, optional):

    Only used with ‘yarn’. Parameter to set the boundary for interpolation (only) in the linear ramp function. If unspecified, it defaults to 1.

    short_factor (List[float], optional):

    Only used with ‘longrope’. The scaling factor to be applied to short contexts (< original_max_position_embeddings). Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2

    long_factor (List[float], optional):

    Only used with ‘longrope’. The scaling factor to be applied to long contexts (> original_max_position_embeddings). Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2.

    low_freq_factor (float, optional):

    Only used with ‘llama3’. Scaling factor applied to low frequency components of the RoPE

    high_freq_factor (float, optional):

    Only used with ‘llama3’. Scaling factor applied to high frequency components of the RoPE

  • rope_local_base_freq (float, optional, defaults to 10000.0) – The base period of the RoPE embeddings for local attention.

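  • sliding_window_pattern (int, optional, defaults to 6) – Pattern for the sliding window attention: every sliding_window_pattern-th layer uses global attention, and the remaining layers use sliding-window attention.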
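sliding_window_pattern, rope_theta, and rope_local_base_freq together determine how each layer attends. The sketch below lists the per-layer layout under the default config values. It assumes the Hugging Face Gemma 3 convention: layer i is global when (i + 1) % sliding_window_pattern == 0 and uses rope_theta; every other layer uses sliding-window attention with rope_local_base_freq. The helper gemma3_layer_types is hypothetical.

```python
def gemma3_layer_types(num_hidden_layers=26, sliding_window_pattern=6,
                       rope_theta=1_000_000.0, rope_local_base_freq=10_000.0):
    """Hypothetical helper: list (attention kind, RoPE base) for each layer.

    Assumes the Hugging Face Gemma 3 convention described above.
    """
    layers = []
    for i in range(num_hidden_layers):
        if (i + 1) % sliding_window_pattern == 0:
            layers.append(("global", rope_theta))
        else:
            layers.append(("sliding", rope_local_base_freq))
    return layers


layout = gemma3_layer_types()
global_layers = [i for i, (kind, _) in enumerate(layout) if kind == "global"]
```

With the defaults (26 layers, pattern 6), only four layers attend globally; the rest see at most sliding_window tokens.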

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

The attach_custom_arguments function adds the following arguments to the Transformer class:

Parameters
  • self – Refer to the current object

  • gradient_checkpointing – EasyDeLGradientCheckPointers: The gradient checkpointing (rematerialization) policy, which trades recomputation for lower activation memory.

  • bits – tp.Optional[int]: Determine the number of bits used in the quantization

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()[source]#
property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
model_type: str = 'gemma3_text'#
static rng_keys()[source]#
class easydel.__init__.Gemma3TextModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

property default_frequencies#
class easydel.__init__.GemmaConfig(vocab_size=256000, hidden_size=3072, intermediate_size=24576, num_hidden_layers=28, num_attention_heads=16, num_key_value_heads=16, head_dim=256, hidden_act='gelu_pytorch_tanh', max_position_embeddings=8192, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, pad_token_id=0, eos_token_id=1, bos_token_id=2, tie_word_embeddings=True, rope_theta=10000.0, attention_bias=False, attention_dropout=0.0, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, scan_layers: bool = False, hidden_activation='gelu_pytorch_tanh', **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 256000) – Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 3072) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 24576) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 28) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 16) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 16) – Number of key and value heads for each attention layer in the Transformer encoder.

  • head_dim (int, optional, defaults to 256) – Dimensionality of the attention head.

  • hidden_act (str or function, optional, defaults to “gelu_pytorch_tanh”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 8192) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional, defaults to 0) – The index of the padding token in the vocabulary.

  • eos_token_id (int, optional, defaults to 1) – The index of the end of sequence token in the vocabulary.

  • bos_token_id (int, optional, defaults to 2) – The index of the beginning of sequence token in the vocabulary.

  • tie_word_embeddings (bool, optional, defaults to True) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • attention_bias (bool, optional, defaults to False) – Whether to use attention bias.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing (rematerialization) policy.

  • bits (int, optional) – The number of bits to quantize the model to.

  • scan_layers (bool, optional, defaults to False) – Whether to use the scan implementation of the layers.

  • hidden_activation (str, optional) – The hidden activation function to use.

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

The attach_custom_arguments function adds the following arguments to the Transformer class:

Parameters
  • self – Refer to the current object

  • gradient_checkpointing – EasyDeLGradientCheckPointers: The gradient checkpointing (rematerialization) policy, which trades recomputation for lower activation memory.

  • bits – tp.Optional[int]: Determine the number of bits used in the quantization

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()[source]#
property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
model_type: str = 'gemma'#
static rng_keys()[source]#
class easydel.__init__.GemmaForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.GemmaForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.GemmaModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.Grok1Config(vocab_size=32000, hidden_size=4096, intermediate_size=32768, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=32, attn_output_multiplier=1.0, max_attn_value=1.0, max_position_embeddings=4096, embedding_multiplier_scale: float = 1.0, output_multiplier_scale: float = 1.0, rms_norm_eps=1e-05, use_cache=True, pad_token_id=None, bos_token_id=1, eos_token_id=2, tie_word_embeddings=True, num_experts_per_tok=2, num_experts=8, output_router_logits=False, router_aux_loss_coef=0.001, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 32000) – Vocabulary size of the Grok-1 model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 32768) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 32) – Number of key and value heads for each attention layer in the Transformer encoder.

  • attn_output_multiplier (float, optional, defaults to 1.0) – The multiplier value applied to the attention output.

  • max_attn_value (float, optional, defaults to 1.0) – The maximum value of the attention weights.

  • max_position_embeddings (int, optional, defaults to 4096) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • embedding_multiplier_scale (float, optional, defaults to 1.0) – The scale factor for the embedding layer.

  • output_multiplier_scale (float, optional, defaults to 1.0) – The scale factor for the output layer.

  • rms_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional) – The index of the padding token in the vocabulary.

  • bos_token_id (int, optional, defaults to 1) – The index of the beginning of sequence token in the vocabulary.

  • eos_token_id (int, optional, defaults to 2) – The index of the end of sequence token in the vocabulary.

  • tie_word_embeddings (bool, optional, defaults to True) – Whether to tie the weights of the input embeddings and the output embeddings.

  • num_experts_per_tok (int, optional, defaults to 2) – The number of experts per token.

  • num_experts (int, optional, defaults to 8) – The number of experts.

  • output_router_logits (bool, optional, defaults to False) – Whether to output router logits.

  • router_aux_loss_coef (float, optional, defaults to 0.001) – The router auxiliary loss coefficient.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing (rematerialization) policy.

  • bits (int, optional) – The number of bits to quantize the model to.
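Grok-1's MoE layers send each token to num_experts_per_tok of num_experts experts, and router_aux_loss_coef weights an auxiliary load-balancing term. The sketch below assumes a common Switch-Transformer-style formulation:

- a softmax router;
- top-k selection with renormalized weights;
- loss = num_experts * sum_i f_i * P_i.

EasyDeL's Grok-1 module may differ in details such as weight normalization. The helper names are hypothetical.

```python
import jax
import jax.numpy as jnp


def route_tokens(router_logits, num_experts_per_tok=2):
    """Pick the top-k experts per token and renormalize their routing weights.

    router_logits: [num_tokens, num_experts].
    Returns (weights [num_tokens, k], expert_ids [num_tokens, k]).
    """
    probs = jax.nn.softmax(router_logits, axis=-1)
    weights, expert_ids = jax.lax.top_k(probs, num_experts_per_tok)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights, expert_ids


def load_balancing_loss(router_logits, expert_ids, num_experts):
    """Auxiliary loss num_experts * sum_i f_i * P_i.

    f_i: fraction of routing assignments sent to expert i.
    P_i: mean router probability assigned to expert i.
    """
    probs = jax.nn.softmax(router_logits, axis=-1)
    one_hot = jax.nn.one_hot(expert_ids, num_experts)  # [tokens, k, experts]
    f = one_hot.sum(axis=(0, 1)) / expert_ids.size
    p = probs.mean(axis=0)
    return num_experts * jnp.sum(f * p)


logits = jax.random.normal(jax.random.PRNGKey(0), (16, 8))  # 16 tokens, 8 experts
weights, ids = route_tokens(logits)
aux = 0.001 * load_balancing_loss(logits, ids, num_experts=8)  # scaled by router_aux_loss_coef
```

With a perfectly uniform router, the loss equals 1.0; it grows as routing concentrates on a few experts.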

attach_custom_arguments(tie_word_embeddings: bool = False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

The attach_custom_arguments function adds the following arguments to the Transformer class:

Parameters
  • self – Refer to the current object

  • tie_word_embeddings – bool: Tie the word embeddings to the decoder

  • gradient_checkpointing – EasyDeLGradientCheckPointers: The gradient checkpointing (rematerialization) policy, which trades recomputation for lower activation memory.

  • bits – tp.Optional[int]: Determine the number of bits used in the quantization

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()[source]#
property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
model_type: str = 'grok-1'#
static rng_keys()[source]#
class easydel.__init__.Grok1ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.Grok1Model(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

property frequencies#

Returns frequency values from the config.

class easydel.__init__.InternLM2Config(vocab_size=103168, hidden_size=4096, intermediate_size=11008, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=None, hidden_act='silu', max_position_embeddings=2048, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, pad_token_id=0, bos_token_id=1, eos_token_id=2, pretraining_tp=1, tie_word_embeddings=False, bias=True, rope_theta=10000, rope_scaling=None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, fcm_min_ratio: float = -1, fcm_max_ratio: float = -1, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, scan_layers: bool = False, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 103168) – Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 11008) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer encoder. Will default to num_attention_heads if not set.

  • max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional, defaults to 0) – The id of the pad token.

  • bos_token_id (int, optional, defaults to 1) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • bias (bool, optional, defaults to True) – Whether to use attention bias.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing (rematerialization) policy.

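  • fcm_min_ratio (float, optional, defaults to -1) – The minimum masking ratio for forgetful causal masking (FCM) during training.

  • fcm_max_ratio (float, optional, defaults to -1) – The maximum masking ratio for forgetful causal masking (FCM) during training; FCM is only applied when this is greater than 0.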

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.

  • scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size to use when scanning the MLP.

  • bits (int, optional) – The number of bits to quantize the model to.

  • hidden_act (str, optional, defaults to “silu”) – The hidden activation function to use.

  • pretraining_tp (int, optional, defaults to 1) – The tensor parallelism degree used during pretraining.

  • mlp_bias (bool, optional, defaults to False) – Whether to use bias in the MLP.

  • scan_layers (bool, optional, defaults to False) – Whether to use the scan implementation for the layers.
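rope_theta sets the base of the rotary position embedding (RoPE) frequencies. The sketch below assumes the standard RoPE formulation, in which dimension pair i of a head of size d rotates with inverse frequency 1 / rope_theta^(2i/d). The helper names are hypothetical and not part of EasyDeL, and rope_scaling is ignored.

```python
import math
from typing import List, Tuple


def rope_inv_freq(head_dim: int, rope_theta: float = 10000.0) -> List[float]:
    """Inverse frequency for each rotated pair of dimensions (hypothetical helper)."""
    if head_dim % 2 != 0:
        raise ValueError("head_dim must be even for rotary embeddings")
    # inv_freq[i] = 1 / rope_theta ** (2i / head_dim), for i = 0 .. head_dim/2 - 1
    return [1.0 / (rope_theta ** (2 * i / head_dim)) for i in range(head_dim // 2)]


def rope_angles(position: int, head_dim: int, rope_theta: float = 10000.0) -> List[float]:
    """Rotation angle (radians) applied to each dimension pair at `position`."""
    return [position * f for f in rope_inv_freq(head_dim, rope_theta)]


def rotate_pair(x0: float, x1: float, angle: float) -> Tuple[float, float]:
    """Rotate one (x0, x1) pair by `angle`, as RoPE does for queries and keys."""
    c, s = math.cos(angle), math.sin(angle)
    return (x0 * c - x1 * s, x0 * s + x1 * c)


if __name__ == "__main__":
    print(rope_inv_freq(head_dim=8))  # roughly [1.0, 0.1, 0.01, 0.001]
    print(rope_angles(position=3, head_dim=8))
```

Raising rope_theta lowers the slowest frequencies, so the rotation angles repeat only over longer distances.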

attach_custom_arguments(tie_word_embeddings: bool = False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, fcm_min_ratio: float = 0.0, fcm_max_ratio: float = 0.0, bits: Optional[int] = None, rope_theta: float = 10000.0, hidden_act: str = 'silu', scan_layers: bool = True, **kwargs)[source]#

Attach the following custom arguments to the configuration:

Parameters
  • tie_word_embeddings – bool: Whether to tie the input word embeddings to the output embeddings.

  • gradient_checkpointing – EasyDeLGradientCheckPointers: The gradient checkpointing policy, which controls which intermediate activations are saved versus recomputed during the backward pass.

  • fcm_min_ratio – float: The minimum ratio for forgetful causal masking (FCM).

  • fcm_max_ratio – float: The maximum ratio for forgetful causal masking (FCM).

  • bits – tp.Optional[int]: The number of bits used for quantization.

  • rope_theta – float: The base (theta) used to compute the rotary position embeddings.

  • hidden_act – str: The activation function used in the MLP.

  • scan_layers – bool: Whether to use the scan implementation for the layers.
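The fcm_min_ratio and fcm_max_ratio arguments appear to come from EasyLM's forgetful causal masking (FCM), in which each training sequence randomly hides a fraction of the visible past tokens from attention. The sketch below assumes EasyLM's behaviour (a drop ratio sampled uniformly from [fcm_min_ratio, fcm_max_ratio], the first token always kept, and FCM disabled when fcm_max_ratio is not positive). It is not EasyDeL's implementation, and it uses plain Python lists instead of JAX arrays.

```python
import random
from typing import List, Optional


def forgetful_causal_mask(
    seq_len: int,
    fcm_min_ratio: float,
    fcm_max_ratio: float,
    rng: Optional[random.Random] = None,
) -> List[List[bool]]:
    """Causal mask with random forgetting of past positions (EasyLM-style sketch).

    mask[q][k] is True when query position q may attend to key position k.
    """
    causal = [[k <= q for k in range(seq_len)] for q in range(seq_len)]
    if fcm_max_ratio <= 0:
        # FCM disabled (e.g. the -1 defaults): return the plain causal mask.
        return causal
    rng = rng or random.Random()
    # One drop ratio for the whole sequence, sampled uniformly in [min, max].
    ratio = rng.uniform(fcm_min_ratio, fcm_max_ratio)
    # Keep an allowed (causal) entry with probability 1 - ratio; always keep token 0.
    return [
        [causal[q][k] and (k == 0 or rng.random() > ratio) for k in range(seq_len)]
        for q in range(seq_len)
    ]
```

Under this assumption, the documented defaults of -1 for both ratios leave the ordinary causal mask untouched.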

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]
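Partition rules in this (regex, PartitionSpec) format are commonly applied by matching each parameter's path against the patterns in order and taking the first hit. The sketch below assumes that first-match convention. The rule patterns, the "tp"/"fsdp" axis names, and match_partition_rules itself are hypothetical, and plain tuples stand in for jax.sharding.PartitionSpec so the example runs without JAX.

```python
import re
from typing import Dict, Optional, Sequence, Tuple

# Stand-in for jax.sharding.PartitionSpec: one mesh-axis name (or None) per
# tensor dimension; an empty tuple means "fully replicated".
Spec = Tuple[Optional[str], ...]

# Hypothetical rules in the tp.Tuple[tp.Tuple[str, PartitionSpec]] shape.
RULES: Tuple[Tuple[str, Spec], ...] = (
    ("embed_tokens/embedding", ("tp", "fsdp")),
    ("self_attn/(q_proj|k_proj|v_proj)/kernel", ("fsdp", "tp")),
    ("self_attn/o_proj/kernel", ("tp", "fsdp")),
    (".*", ()),  # catch-all: replicate everything else
)


def match_partition_rules(
    rules: Sequence[Tuple[str, Spec]], param_names: Sequence[str]
) -> Dict[str, Spec]:
    """Give each parameter the spec of the first rule whose regex matches its path."""
    assigned: Dict[str, Spec] = {}
    for name in param_names:
        for pattern, spec in rules:
            if re.search(pattern, name) is not None:
                assigned[name] = spec
                break
        else:  # no rule matched this parameter
            raise ValueError(f"no partition rule matches {name!r}")
    return assigned
```

Ordering matters with first-match semantics: a broad catch-all such as ".*" has to come last, otherwise it shadows the more specific rules.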

static get_weight_decay_exclusions()[source]#
property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
model_type: str = 'internlm2'#
static rng_keys()[source]#
class easydel.__init__.InternLM2ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.InternLM2ForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.InternLM2Model(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.JaxDistributedConfig[source]#

Bases: object

Utility class for initializing JAX distributed, adapted from EasyLM.

static get_default_config(updates=None)[source]#
classmethod initialize(config=None)[source]#
class easydel.__init__.LlamaConfig(vocab_size: int = 32000, hidden_size: int = 4096, intermediate_size: int = 11008, num_hidden_layers: int = 32, num_attention_heads: int = 32, number_rep_kv: int = 1, head_dim: Optional[int] = None, num_key_value_heads: Optional[int] = None, max_position_embeddings: int = 2048, rms_norm_eps: float = 1e-06, initializer_range: float = 0.02, use_cache: bool = True, bos_token_id: int = 0, eos_token_id: int = 1, resid_pdrop: float = 0.0, embd_pdrop: float = 0.0, attention_dropout: float = 0.0, rope_theta: float = 10000.0, attention_bias: bool = False, tie_word_embeddings: bool = False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, fcm_min_ratio: float = -1, fcm_max_ratio: float = -1, rope_scaling: Dict[str, Union[str, float]] = None, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, hidden_act: str = 'silu', pretraining_tp: int = 1, mlp_bias: bool = False, scan_layers: bool = False, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 32000) – Vocabulary size of the Llama model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 11008) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.

  • number_rep_kv (int, optional, defaults to 1) – Number of repetitions for the key and value vectors.

  • num_key_value_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer encoder. Will default to number_rep_kv * num_attention_heads if not set.

  • max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • head_dim (int, optional) – The dimensionality of each attention head used for the query, key, and value projections.

  • rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • bos_token_id (int, optional, defaults to 0) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 1) – The id of the end-of-sequence token.

  • resid_pdrop (float, optional, defaults to 0.0) – The dropout probability for the residual connections.

  • embd_pdrop (float, optional, defaults to 0.0) – The dropout ratio for the embeddings.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • attention_bias (bool, optional, defaults to False) – Whether to use attention bias.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing policy.

  • fcm_min_ratio (float, optional, defaults to -1) – The minimum ratio for forgetful causal masking (FCM).

  • fcm_max_ratio (float, optional, defaults to -1) – The maximum ratio for forgetful causal masking (FCM).

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.

  • scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size to use when scanning the MLP.

  • bits (int, optional) – The number of bits to quantize the model to.

  • hidden_act (str, optional, defaults to “silu”) – The hidden activation function to use.

  • pretraining_tp (int, optional, defaults to 1) – The tensor parallelism degree used during pretraining.

  • mlp_bias (bool, optional, defaults to False) – Whether to use bias in the MLP.

  • scan_layers (bool, optional, defaults to False) – Whether to use the scan implementation for the layers.
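scan_mlp_chunk_size sets the chunk size for the scanned MLP. Because the MLP acts on each position independently, the sequence can be split into chunks of scan_mlp_chunk_size tokens and processed one chunk at a time, which lowers peak activation memory without changing the result. The sketch below illustrates that idea in plain Python with a hypothetical chunked_apply helper. A JAX implementation would more likely run jax.lax.scan over the chunks, and EasyDeL's exact code path is not confirmed here.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def chunked_apply(
    fn: Callable[[List[T]], List[U]], tokens: Sequence[T], chunk_size: int = 1024
) -> List[U]:
    """Apply a position-wise function to consecutive chunks of the sequence.

    Only one chunk is passed to `fn` at a time, which is the memory saving a
    chunked (scanned) MLP aims for.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    out: List[U] = []
    for start in range(0, len(tokens), chunk_size):
        out.extend(fn(list(tokens[start:start + chunk_size])))
    return out


def toy_mlp(xs: List[int]) -> List[int]:
    """Stands in for a real position-wise MLP."""
    return [2 * x + 1 for x in xs]


if __name__ == "__main__":
    seq = list(range(10))
    # Same result as applying the MLP to the whole sequence at once.
    print(chunked_apply(toy_mlp, seq, chunk_size=4) == toy_mlp(seq))  # True
```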

attach_custom_arguments(resid_pdrop: float = 0.0, embd_pdrop: float = 0.0, attention_dropout: float = 0.0, tie_word_embeddings: bool = False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, fcm_min_ratio: float = 0.0, fcm_max_ratio: float = 0.0, number_rep_kv: int = 1, bits: Optional[int] = None, rope_theta: float = 10000.0, attention_bias: bool = False, hidden_act: str = 'silu', scan_layers: bool = True, **kwargs)[source]#

Attach the following custom arguments to the configuration:

Parameters
  • resid_pdrop – float: The dropout rate for residual connections.

  • embd_pdrop – float: The dropout rate for the embeddings.

  • attention_dropout – float: The dropout rate for the attention probabilities.

  • tie_word_embeddings – bool: Whether to tie the input word embeddings to the output embeddings.

  • gradient_checkpointing – EasyDeLGradientCheckPointers: The gradient checkpointing policy, which controls which intermediate activations are saved versus recomputed during the backward pass.

  • fcm_min_ratio – float: The minimum ratio for forgetful causal masking (FCM).

  • fcm_max_ratio – float: The maximum ratio for forgetful causal masking (FCM).

  • number_rep_kv – int: The number of times the key and value vectors are repeated.

  • bits – tp.Optional[int]: The number of bits used for quantization.

  • rope_theta – float: The base (theta) used to compute the rotary position embeddings.

  • attention_bias – bool: Whether to use bias in the attention layers.

  • hidden_act – str: The activation function used in the MLP.

  • scan_layers – bool: Whether to use the scan implementation for the layers.
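When num_key_value_heads is smaller than num_attention_heads, each key/value head is shared by a group of query heads (grouped-query attention), and the key/value tensors are repeated so that there is one per query head before attention is computed. The sketch below shows that bookkeeping with hypothetical helper names, assuming the common layout in which each KV head is repeated consecutively. How number_rep_kv maps onto this repetition factor inside EasyDeL is not confirmed here.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def kv_group_size(num_attention_heads: int, num_key_value_heads: int) -> int:
    """Number of query heads that share each key/value head."""
    if num_attention_heads % num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    return num_attention_heads // num_key_value_heads


def repeat_kv(kv_heads: Sequence[T], n_rep: int) -> List[T]:
    """Repeat each KV head n_rep times: [kv0, kv0, ..., kv1, kv1, ...]."""
    return [head for head in kv_heads for _ in range(n_rep)]


if __name__ == "__main__":
    n_rep = kv_group_size(num_attention_heads=32, num_key_value_heads=8)  # 4
    expanded = repeat_kv([f"kv{i}" for i in range(8)], n_rep)
    # Query head h now lines up with KV head h // n_rep.
    print(expanded[:5])  # ['kv0', 'kv0', 'kv0', 'kv0', 'kv1']
```

When num_key_value_heads equals num_attention_heads the group size is 1, and this reduces to standard multi-head attention.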

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()[source]#
property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
model_type: str = 'llama'#
static rng_keys()[source]#
class easydel.__init__.LlamaForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

config#

Configuration for the model.

Type

LlamaConfig

dtype#

Data type for computations (default is jnp.bfloat16).

Type

jnp.dtype

param_dtype#

Data type for parameters (default is jnp.bfloat16).

Type

jnp.dtype

precision#

Precision setting for JAX operations (default is “fastest”).

Type

tp.Optional[tp.Union[str, jax.lax.Precision]]

class easydel.__init__.LlamaForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.LlamaModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.LlavaConfig(vision_config=None, text_config=None, image_token_index=32000, projector_hidden_act='gelu', vision_feature_select_strategy='default', vision_feature_layer=-2, image_seq_length=576, multimodal_projector_bias=True, **kwargs)[source]#

Bases: EasyDeLBaseConfig

This is the configuration class to store the configuration of a [LlavaForConditionalGeneration]. It is used to instantiate a Llava model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Llava-9B.

e.g. [llava-hf/llava-9b](https://huggingface.co/llava-hf/llava-9b)

Configuration objects inherit from [PretrainedConfig] and can be used to control the model outputs. Read the documentation from [PretrainedConfig] for more information.

Parameters
  • vision_config (Union[AutoConfig, dict], optional, defaults to CLIPVisionConfig) – The config object or dictionary of the vision backbone.

  • text_config (Union[AutoConfig, dict], optional, defaults to LlamaConfig) – The config object or dictionary of the text backbone.

  • image_token_index (int, optional, defaults to 32000) – The image token index to encode the image prompt.

  • projector_hidden_act (str, optional, defaults to “gelu”) – The activation function used by the multimodal projector.

  • vision_feature_select_strategy (str, optional, defaults to “default”) – The feature selection strategy used to select the vision feature from the vision backbone. Can be one of “default” or “full”.

  • vision_feature_layer (Union[int, List[int]], optional, defaults to -2) – The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features.

  • image_seq_length (int, optional, defaults to 576) – Sequence length of one image embedding.

  • multimodal_projector_bias (bool, optional, defaults to True) – Whether to use bias in the multimodal projector.
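The vision_feature_layer and vision_feature_select_strategy options can be illustrated with plain lists. The sketch assumes the Hugging Face LLaVA convention: "default" drops the leading CLS token, "full" keeps every token, and several selected layers are concatenated along the feature axis. select_vision_features and num_image_tokens are hypothetical helpers. For reference, the image_seq_length default of 576 equals the patch count of a 336×336 image split into 14×14 patches (24 × 24 = 576), the setup of the CLIP ViT-L/14-336 encoder used by LLaVA-1.5.

```python
from typing import List, Sequence, Union

Vector = List[float]          # one token's feature vector
LayerOutput = List[Vector]    # all token vectors produced by one vision layer


def num_image_tokens(image_size: int, patch_size: int) -> int:
    """Patch tokens a ViT produces for a square image, excluding the CLS token."""
    return (image_size // patch_size) ** 2


def select_vision_features(
    hidden_states: Sequence[LayerOutput],
    vision_feature_layer: Union[int, Sequence[int]] = -2,
    strategy: str = "default",
) -> LayerOutput:
    """Pick vision-tower layer(s) and apply the feature selection strategy."""
    if isinstance(vision_feature_layer, int):
        layers = [vision_feature_layer]
    else:
        layers = list(vision_feature_layer)
    if strategy not in ("default", "full"):
        raise ValueError(f"unknown vision_feature_select_strategy {strategy!r}")
    per_layer = []
    for idx in layers:
        tokens = list(hidden_states[idx])
        if strategy == "default":
            tokens = tokens[1:]  # drop the leading CLS token
        per_layer.append(tokens)
    # Several layers: concatenate each token's vectors along the feature axis.
    return [[x for vec in vectors for x in vec] for vectors in zip(*per_layer)]
```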

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'llava'#
sub_configs: Dict[str, 'PretrainedConfig'] = {'text_config': <class 'easydel.__init__.modules.auto.auto_configuration.AutoEasyDeLConfig'>, 'vision_config': <class 'easydel.__init__.modules.auto.auto_configuration.AutoEasyDeLConfig'>}#
class easydel.__init__.LlavaForConditionalGeneration(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

get_image_features(pixel_values: Union[Array, ndarray, bool, number]) → Union[Array, ndarray, bool, number][source]#
loss_type = 'ForCausalLM'#
prepare_inputs_for_generation(input_ids: Union[Array, ndarray, bool, number], max_length: int, pixel_values: Optional[Union[Array, ndarray, bool, number]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Prepare the inputs for a generation task.

Parameters
  • input_ids – The input token IDs.

  • max_length – The maximum length of the sequence to be generated.

  • pixel_values – tp.Optional[chex.Array]: The pixel values of the input images.

  • attention_mask – tp.Optional[chex.Array]: The attention mask for the input tokens.

Returns

A dictionary of the past_key_values, attention_mask and position ids
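A common way to build position ids for left-padded prompts is to take a running sum of the attention mask and subtract one, so the first real token gets position 0 regardless of how much padding precedes it, and to extend the mask with ones up to max_length for the tokens that will be generated. The sketch below shows that approach with hypothetical helpers and plain lists. It is not necessarily what EasyDeL's prepare_inputs_for_generation does, and clamping padded positions to 0 is an illustrative choice.

```python
from typing import List, Sequence


def extend_attention_mask(mask: Sequence[Sequence[int]], max_length: int) -> List[List[int]]:
    """Pad each prompt mask with ones up to max_length for the tokens to be generated."""
    return [list(row) + [1] * (max_length - len(row)) for row in mask]


def position_ids_from_mask(mask: Sequence[Sequence[int]]) -> List[List[int]]:
    """Running count of real tokens minus one; padded slots are clamped to 0."""
    result = []
    for row in mask:
        seen, ids = 0, []
        for m in row:
            seen += m
            ids.append(max(seen - 1, 0))
        result.append(ids)
    return result


if __name__ == "__main__":
    prompt_mask = [[0, 0, 1, 1, 1],   # left-padded prompt
                   [1, 1, 1, 1, 1]]
    print(position_ids_from_mask(prompt_mask))  # [[0, 0, 0, 1, 2], [0, 1, 2, 3, 4]]
    print(extend_attention_mask(prompt_mask, 7)[0])  # [0, 0, 1, 1, 1, 1, 1]
```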

update_inputs_for_generation(model_outputs, model_kwargs)[source]#
class easydel.__init__.LossConfig(ignore_index: int = -100, label_smoothing: float = 0.0, z_loss: float = 0.0, loss_normalizing_factor: Union[float, int, str, easydel.__init__.infra.loss_utils.SpecialLossNormalizingFactor, NoneType] = 'NUM_REAL_TARGET_TOKENS', num_labels: Optional[str] = None, problem_type: Optional[str] = None, divide_weight_sum: bool = False, shift_tokens: bool = True, break_on_nan: bool = True, reduction: Optional[Literal['none', 'mean', 'sum']] = None, num_classification_labels: Optional[int] = None, classification_problem_type: Optional[Literal['regression', 'single_label_classification', 'multi_label_classification']] = None)[source]#

Bases: Mapping

break_on_nan: bool = True#
classification_problem_type: Optional[Literal['regression', 'single_label_classification', 'multi_label_classification']] = None#
divide_weight_sum: bool = False#
from_tuple()#
ignore_index: int = -100#
items() → a set-like object providing a view on D's items#
keys() → a set-like object providing a view on D's keys#
label_smoothing: float = 0.0#
loss_normalizing_factor: Optional[Union[float, int, str, SpecialLossNormalizingFactor]] = 'NUM_REAL_TARGET_TOKENS'#
num_classification_labels: Optional[int] = None#
num_labels: Optional[str] = None#
problem_type: Optional[str] = None#
reduction: Optional[Literal['none', 'mean', 'sum']] = None#
replace(**kwargs)#
shift_tokens: bool = True#
to_tuple()#
values() → an object providing a view on D's values#
z_loss: float = 0.0#
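LossConfig exposes the knobs of a token-level cross-entropy loss. The sketch below is one plausible reading of the main fields, not EasyDeL's implementation: shift_tokens aligns position t with label t + 1, labels equal to ignore_index are skipped, label_smoothing mixes the one-hot target with a uniform distribution, z_loss adds z_loss * log(Z)**2 per token (as in T5X-style z-loss), and the total is divided by the number of non-ignored targets, which is how we interpret "NUM_REAL_TARGET_TOKENS". causal_lm_loss is a hypothetical name.

```python
import math
from typing import Sequence


def logsumexp(xs: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(x)))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def causal_lm_loss(
    logits: Sequence[Sequence[float]],  # [seq_len][vocab_size]
    labels: Sequence[int],              # [seq_len]
    ignore_index: int = -100,
    label_smoothing: float = 0.0,
    z_loss: float = 0.0,
    shift_tokens: bool = True,
) -> float:
    """Token-level cross-entropy averaged over non-ignored targets (sketch)."""
    if shift_tokens:
        # Position t predicts token t + 1.
        logits, labels = logits[:-1], labels[1:]
    total, real_targets = 0.0, 0
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue
        log_z = logsumexp(row)
        log_probs = [x - log_z for x in row]
        nll = -log_probs[label]
        # Uniform part of the smoothed target: equal weight on every class.
        uniform_nll = -sum(log_probs) / len(row)
        total += (1.0 - label_smoothing) * nll + label_smoothing * uniform_nll
        # z-loss pushes the log-partition function log(Z) toward zero.
        total += z_loss * log_z ** 2
        real_targets += 1
    # "NUM_REAL_TARGET_TOKENS"-style normalization (our interpretation).
    return total / max(real_targets, 1)
```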
class easydel.__init__.Mamba2Config(num_heads=128, head_dim=64, vocab_size=32768, hidden_size=4096, state_size=128, num_hidden_layers=64, layer_norm_epsilon=1e-05, pad_token_id=1, bos_token_id=0, eos_token_id=2, expand=2, conv_kernel=4, n_groups=8, use_bias=False, use_conv_bias=True, hidden_act='silu', initializer_range=0.1, residual_in_fp32=True, time_step_rank='auto', time_step_min=0.001, time_step_max=0.1, time_step_floor=0.0001, time_step_limit=(0.0, inf), rescale_prenorm_residual=False, use_cache=True, norm_before_gate=True, rms_norm=True, chunk_size=256, tie_word_embeddings=False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 32768) – Vocabulary size of the Mamba2 model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4096) – Dimensionality of the embeddings and hidden states.

  • state_size (int, optional, defaults to 128) – State size of the Mamba2 model.

  • num_hidden_layers (int, optional, defaults to 64) – Number of hidden layers in the model.

  • layer_norm_epsilon (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.

  • pad_token_id (int, optional, defaults to 1) – The index of the padding token in the vocabulary.

  • bos_token_id (int, optional, defaults to 0) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.

  • expand (int, optional, defaults to 2) – Expansion factor for the intermediate size.

  • conv_kernel (int, optional, defaults to 4) – Kernel size of the convolution layer.

  • use_bias (bool, optional, defaults to False) – Whether to use bias in the linear layers.

  • use_conv_bias (bool, optional, defaults to True) – Whether to use bias in the convolution layer.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • initializer_range (float, optional, defaults to 0.1) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • residual_in_fp32 (bool, optional, defaults to True) – Whether to compute the residual connection in float32.

  • time_step_rank (str or int, optional, defaults to “auto”) – The rank of the time step embedding. If set to “auto”, the rank is calculated as math.ceil(self.hidden_size / 16).

  • time_step_scale (float, optional, defaults to 1.0) – The scale factor for the time step embedding.

  • time_step_min (float, optional, defaults to 0.001) – The minimum value for the time step embedding.

  • time_step_max (float, optional, defaults to 0.1) – The maximum value for the time step embedding.

  • time_step_floor (float, optional, defaults to 1e-4) – The floor value for the time step embedding.

  • rescale_prenorm_residual (bool, optional, defaults to False) – Whether to rescale the pre-norm residual.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing policy.
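Several Mamba2Config defaults are tied together arithmetically. The docstring states that time_step_rank="auto" resolves to math.ceil(hidden_size / 16), which gives 256 for the signature's default hidden_size of 4096. The expanded inner width expand * hidden_size = 2 * 4096 = 8192 also equals num_heads * head_dim = 128 * 64 = 8192. The sketch below assumes, as in the reference Mamba-2 design, that the SSM heads must tile that width exactly. Mamba2Dims is a hypothetical helper, not part of EasyDeL.

```python
import math
from dataclasses import dataclass
from typing import Union


@dataclass
class Mamba2Dims:
    """Hypothetical helper mirroring a few Mamba2Config defaults."""

    hidden_size: int = 4096
    expand: int = 2
    num_heads: int = 128
    head_dim: int = 64
    time_step_rank: Union[str, int] = "auto"

    @property
    def intermediate_size(self) -> int:
        # Expanded inner width of each mixer block.
        return self.expand * self.hidden_size

    @property
    def resolved_time_step_rank(self) -> int:
        # Documented rule: "auto" means math.ceil(hidden_size / 16).
        if self.time_step_rank == "auto":
            return math.ceil(self.hidden_size / 16)
        return int(self.time_step_rank)

    def check(self) -> None:
        # Assumption: the SSM heads tile the expanded width exactly.
        if self.num_heads * self.head_dim != self.intermediate_size:
            raise ValueError(
                f"num_heads * head_dim = {self.num_heads * self.head_dim} "
                f"does not match expand * hidden_size = {self.intermediate_size}"
            )


if __name__ == "__main__":
    dims = Mamba2Dims()
    dims.check()  # passes: 128 * 64 == 2 * 4096 == 8192
    print(dims.intermediate_size, dims.resolved_time_step_rank)  # 8192 256
```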

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE)[source]#
get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'mamba2'#
class easydel.__init__.Mamba2ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

init_cache(batch_size: int, max_length: int)[source]#
prepare_inputs_for_generation(input_ids, inputs_embeds=None, cache_params: Optional[Mamba2Cache] = None, cache_position: Optional[Union[Array, ndarray, bool, number]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None, **kwargs)[source]#

Prepare the inputs for a generation task.

Parameters
  • input_ids – The input token IDs.

  • inputs_embeds – Optional precomputed input embeddings.

  • cache_params – tp.Optional[Mamba2Cache]: The cached state from previous generation steps.

  • cache_position – tp.Optional[chex.Array]: The positions of the current tokens within the cache.

  • attention_mask – tp.Optional[chex.Array]: The attention mask for the input tokens.

  • kwargs – Additional keyword arguments.

Returns

A dictionary of the past_key_values, attention_mask and position ids

update_inputs_for_generation(model_outputs, model_kwargs)[source]#
class easydel.__init__.Mamba2Model(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.MambaConfig(vocab_size=50280, hidden_size=768, state_size=16, num_hidden_layers=32, layer_norm_epsilon=1e-05, pad_token_id=0, bos_token_id=0, eos_token_id=0, expand=2, conv_kernel=4, use_bias=False, use_conv_bias=True, hidden_act='silu', initializer_range=0.1, residual_in_fp32=True, time_step_rank='auto', time_step_scale=1.0, time_step_min=0.001, time_step_max=0.1, time_step_init_scheme='random', time_step_floor=0.0001, rescale_prenorm_residual=False, use_cache=True, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_mambapy: bool = False, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 50280) – Vocabulary size of the Mamba model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 768) – Dimensionality of the embeddings and hidden states.

  • state_size (int, optional, defaults to 16) – State size of the Mamba model.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the model.

  • layer_norm_epsilon (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.

  • pad_token_id (int, optional, defaults to 0) – The index of the padding token in the vocabulary.

  • bos_token_id (int, optional, defaults to 0) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 0) – The id of the end-of-sequence token.

  • expand (int, optional, defaults to 2) – Expansion factor for the intermediate size.

  • conv_kernel (int, optional, defaults to 4) – Kernel size of the convolution layer.

  • use_bias (bool, optional, defaults to False) – Whether to use bias in the linear layers.

  • use_conv_bias (bool, optional, defaults to True) – Whether to use bias in the convolution layer.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • initializer_range (float, optional, defaults to 0.1) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • residual_in_fp32 (bool, optional, defaults to True) – Whether to compute the residual connection in float32.

  • time_step_rank (str or int, optional, defaults to “auto”) – The rank of the time step embedding. If set to “auto”, the rank is calculated as math.ceil(self.hidden_size / 16).

  • time_step_scale (float, optional, defaults to 1.0) – The scale factor for the time step embedding.

  • time_step_min (float, optional, defaults to 0.001) – The minimum value for the time step embedding.

  • time_step_max (float, optional, defaults to 0.1) – The maximum value for the time step embedding.

  • time_step_init_scheme (str, optional, defaults to “random”) – The initialization scheme for the time step embedding. Possible values are “random” and “uniform”.

  • time_step_floor (float, optional, defaults to 1e-4) – The floor value for the time step embedding.

  • rescale_prenorm_residual (bool, optional, defaults to False) – Whether to rescale the pre-norm residual.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing policy.
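time_step_min, time_step_max, and time_step_floor govern how the discretization step (dt) is initialized. Assuming EasyDeL follows the Hugging Face / reference Mamba initialization (not verified here), dt is sampled log-uniformly between time_step_min and time_step_max, clamped from below at time_step_floor, and stored as the bias of the dt projection through an inverse softplus, so that softplus(bias) reproduces dt. The helpers below are hypothetical and use plain Python floats.

```python
import math
import random
from typing import List, Optional


def softplus(x: float) -> float:
    return math.log1p(math.exp(x))


def init_dt_bias(
    size: int,
    time_step_min: float = 0.001,
    time_step_max: float = 0.1,
    time_step_floor: float = 1e-4,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Biases b with softplus(b) = dt, dt log-uniform in [min, max] and >= floor."""
    rng = rng or random.Random()
    lo, hi = math.log(time_step_min), math.log(time_step_max)
    biases = []
    for _ in range(size):
        # Log-uniform sample between time_step_min and time_step_max, then floor it.
        dt = max(math.exp(lo + rng.random() * (hi - lo)), time_step_floor)
        # Inverse softplus: log(exp(dt) - 1), written stably as dt + log(1 - exp(-dt)).
        biases.append(dt + math.log(-math.expm1(-dt)))
    return biases


if __name__ == "__main__":
    biases = init_dt_bias(4, rng=random.Random(0))
    print([round(softplus(b), 5) for b in biases])  # each value lies in [0.001, 0.1)
```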

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE)[source]#
get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'mamba'#
class easydel.__init__.MambaForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

init_cache(batch_size: int, max_length: int)[source]#
prepare_inputs_for_generation(input_ids, max_length, **kwargs)[source]#

Prepare the inputs for a generation task.

Parameters
  • input_ids – The input token IDs.

  • max_length – The maximum length of the sequence to be generated.

  • kwargs – Additional keyword arguments, such as attention_mask.

Returns

A dictionary of the past_key_values, attention_mask and position ids

update_inputs_for_generation(outputs: MambaOutput, model_kwargs: Dict[str, Any], **kwargs) → Dict[str, Any][source]#
class easydel.__init__.MambaModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

init_cache(batch_size: int, max_length: int)[source]#
class easydel.__init__.MistralConfig(vocab_size: int = 32000, hidden_size: int = 4096, intermediate_size: int = 14336, head_dim: int = 128, num_hidden_layers: int = 32, num_attention_heads: int = 32, num_key_value_heads: int = 8, hidden_act='silu', max_position_embeddings=131072, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, pad_token_id=None, bos_token_id: int = 1, eos_token_id: int = 2, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling: Dict[str, Union[str, float]] = None, sliding_window=4096, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, number_rep_kv: int = 1, attention_dropout: float = 0.0, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, attention_bias: bool = False, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 32000) – Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 14336) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • head_dim (int, optional, defaults to 128) – Dimensionality of each attention head.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 8) – Number of key and value heads for each attention layer in the Transformer encoder.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 4096 * 32) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional) – The index of the padding token in the vocabulary.

  • bos_token_id (int, optional, defaults to 1) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.

  • sliding_window (int, optional, defaults to 4096) – The sliding window size.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing policy.

  • number_rep_kv (int, optional, defaults to 1) – Number of repetitions for the key and value vectors.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • use_scan_mlp (bool, optional, defaults to False) – Whether to use the scan implementation for the MLP.

  • scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size to use when scanning the MLP.

  • bits (int, optional) – The number of bits to quantize the model to.

  • attention_bias (bool, optional, defaults to False) – Whether to use bias in the attention layer.
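With sliding_window set, each query position may attend only to a bounded window of recent positions, in addition to the usual causal constraint. The sketch below builds such a mask with plain lists. It assumes the window counts the current token (so a query sees at most sliding_window keys); off-by-one conventions differ between implementations, so treat the exact boundary as an assumption rather than EasyDeL's behaviour.

```python
from typing import List


def sliding_window_causal_mask(seq_len: int, sliding_window: int) -> List[List[bool]]:
    """mask[q][k] is True when query q may attend to key k.

    Causal (k <= q) and restricted to the most recent `sliding_window` positions,
    counting the query's own position.
    """
    if sliding_window <= 0:
        raise ValueError("sliding_window must be positive")
    return [
        [k <= q and q - k < sliding_window for k in range(seq_len)]
        for q in range(seq_len)
    ]


if __name__ == "__main__":
    for row in sliding_window_causal_mask(seq_len=5, sliding_window=2):
        print("".join("x" if allowed else "." for allowed in row))
```

Under this convention, with the default sliding_window of 4096, sequences of at most 4096 tokens get an ordinary causal mask; the window only takes effect beyond that length.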

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, number_rep_kv: int = 1, bits: Optional[int] = None, attention_dropout: float = 0.0, rope_scaling: Dict[str, Union[str, float]] = None, attention_bias: bool = False, **kwargs)[source]#

Attach the following custom arguments to the configuration:

Parameters
  • gradient_checkpointing – EasyDeLGradientCheckPointers: The gradient checkpointing policy, which controls which intermediate activations are saved versus recomputed during the backward pass.

  • use_scan_mlp – bool: Whether to use the scan implementation of the MLP.

  • scan_mlp_chunk_size – int: The chunk size used when scanning the MLP.

  • number_rep_kv – int: The number of times the key and value vectors are repeated.

  • bits – tp.Optional[int]: The number of bits to use for quantization.

  • attention_dropout – float: The dropout rate for the attention probabilities.

  • attention_bias – bool: Whether to use bias in the attention layer.

  • rope_scaling – tp.Dict[str, tp.Union[str, float]]: The RoPE scaling configuration.

Return type

A tuple of the following

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()[source]#
property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
model_type: str = 'mistral'#
static rng_keys()[source]#
class easydel.__init__.MistralForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.MistralForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.MistralModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.MixtralConfig(vocab_size=32000, hidden_size=4096, intermediate_size=14336, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=8, hidden_act='silu', max_position_embeddings=131072, initializer_range=0.02, rms_norm_eps=1e-05, use_cache=True, pad_token_id=None, bos_token_id=1, eos_token_id=2, tie_word_embeddings=False, rope_theta=1000000.0, sliding_window=4096, attention_dropout=0.0, num_experts_per_tok=2, num_local_experts=8, output_router_logits=False, router_aux_loss_coef=0.001, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, number_rep_kv: int = 1, bits: Optional[int] = None, rope_scaling: Dict[str, Union[str, float]] = None, attention_bias: bool = False, initialization_of_moe: bool = False, router_jitter_noise=0.0, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 32000) – Vocabulary size of the Mixtral model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 14336) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 8) – Number of key and value heads for each attention layer in the Transformer encoder.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 4096 * 32) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional) – The index of the padding token in the vocabulary.

  • bos_token_id (int, optional, defaults to 1) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 1e6) – The theta value to use for rotary position embeddings.

  • sliding_window (int, optional, defaults to 4096) – The sliding window size.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • num_experts_per_tok (int, optional, defaults to 2) – The number of experts per token.

  • num_local_experts (int, optional, defaults to 8) – The number of local experts.

  • output_router_logits (bool, optional, defaults to False) – Whether to output router logits.

  • router_aux_loss_coef (float, optional, defaults to 0.001) – The router auxiliary loss coefficient.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • use_scan_mlp (bool, optional, defaults to False) – Whether to use the scan implementation for the MLP.

  • scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size to use when scanning the MLP.

  • number_rep_kv (int, optional, defaults to 1) – Number of repetitions for the key and value vectors.

  • bits (int, optional) – The number of bits to quantize the model to.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.

  • attention_bias (bool, optional, defaults to False) – Whether to use bias in the attention layer.

  • initialization_of_moe (bool, optional, defaults to False) – Whether to initialize the MoE layers.

  • router_jitter_noise (float, optional, defaults to 0.0) – The jitter noise for the router.
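With the defaults above (num_local_experts=8, num_experts_per_tok=2), each token is sent to 2 of 8 experts. The following is a minimal sketch of Mixtral-style top-k routing for a single token. It assumes the expert weights are a softmax renormalized over the selected experts. route_token and moe_output are hypothetical helpers. The sketch leaves out router_jitter_noise and the auxiliary load-balancing loss (router_aux_loss_coef).

```python
import math
from typing import Callable, List, Sequence, Tuple

Expert = Callable[[List[float]], List[float]]


def route_token(router_logits: Sequence[float], num_experts_per_tok: int = 2) -> List[Tuple[int, float]]:
    """Select the top-k experts for one token; return (expert_index, weight) pairs.

    The weights are a softmax over the selected logits only, so they sum to 1.
    This equals a softmax over all experts renormalized over the top-k.
    """
    if not 1 <= num_experts_per_tok <= len(router_logits):
        raise ValueError("num_experts_per_tok must be between 1 and the number of experts")
    # sorted() is stable even with reverse=True, so ties keep the lower index first.
    top = sorted(range(len(router_logits)), key=lambda i: router_logits[i], reverse=True)[:num_experts_per_tok]
    m = max(router_logits[i] for i in top)  # subtract the max for numerical stability
    exps = [math.exp(router_logits[i] - m) for i in top]
    total = sum(exps)
    return [(i, e / total) for i, e in zip(top, exps)]


def moe_output(token: List[float], experts: Sequence[Expert], router_logits: Sequence[float],
               num_experts_per_tok: int = 2) -> List[float]:
    """Weighted sum of the selected experts' outputs for a single token."""
    out = [0.0] * len(token)
    for idx, weight in route_token(router_logits, num_experts_per_tok):
        expert_out = experts[idx](token)
        out = [o + weight * e for o, e in zip(out, expert_out)]
    return out
```

Only the selected experts run for each token. This is why a sparse MoE layer's compute scales with num_experts_per_tok rather than num_local_experts.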

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, number_rep_kv: int = 1, bits: Optional[int] = None, attention_dropout: float = 0.0, rope_scaling: Dict[str, Union[str, float]] = None, attention_bias: bool = False, initialization_of_moe: bool = False, **kwargs)[source]#

The attach_custom_arguments method attaches the following arguments to the configuration:

Parameters
  • self – The configuration instance the arguments are attached to

  • gradient_checkpointing – EasyDeLGradientCheckPointers: The gradient checkpointing strategy to use

  • use_scan_mlp – bool: Determine whether to use the scan_mlp function or not

  • scan_mlp_chunk_size – int: The chunk size used when scanning the MLP

  • number_rep_kv – int: Control the number of times that the key and value vectors are repeated

  • bits – tp.Optional[int]: Specify the number of bits to use for quantization

  • attention_dropout – float: Set the dropout rate for the attention layer

  • attention_bias – bool: Whether to use bias in the attention layer

  • initialization_of_moe – bool: Initializing the MoE layers requires disabling some dynamic parts; set this to True to turn them off.

  • rope_scaling – tp.Dict[str, tp.Union[str, float]]: Configuration for rotary position embedding (RoPE) scaling

Return type

A tuple of the following
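The following is a minimal sketch of what chunking the MLP input by scan_mlp_chunk_size buys. It assumes use_scan_mlp splits the sequence axis into chunks and runs the MLP on one chunk at a time; in JAX this would typically be a jax.lax.scan over a reshaped array. Because the MLP is position-wise, the result matches a single full-sequence call, while only one chunk's intermediate activations are alive at once. chunked_mlp and toy_mlp are hypothetical.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def chunked_mlp(
    hidden_states: Sequence[Vector],
    mlp: Callable[[Sequence[Vector]], List[Vector]],
    scan_mlp_chunk_size: int = 1024,
) -> List[Vector]:
    """Apply a position-wise MLP to consecutive chunks along the sequence axis."""
    if scan_mlp_chunk_size <= 0:
        raise ValueError("scan_mlp_chunk_size must be positive")
    outputs: List[Vector] = []
    for start in range(0, len(hidden_states), scan_mlp_chunk_size):
        chunk = hidden_states[start:start + scan_mlp_chunk_size]
        # Only this chunk's intermediate activations exist at this point.
        outputs.extend(mlp(chunk))
    return outputs


def toy_mlp(chunk: Sequence[Vector]) -> List[Vector]:
    """Hypothetical stand-in for the real MLP; it acts on each position independently."""
    return [[2.0 * max(0.0, v) for v in x] for x in chunk]
```

Smaller chunks lower peak activation memory at the cost of more sequential steps. The last chunk may be shorter here, whereas a jax.lax.scan version would usually require the sequence length to be a multiple of the chunk size.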

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()[source]#
property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
model_type: str = 'mixtral'#
static rng_keys()[source]#
class easydel.__init__.MixtralForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.MixtralForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.MixtralModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.MptAttentionConfig(attn_type='multihead_attention', attn_pdrop=0, attn_impl='torch', clip_qkv=None, softmax_scale=None, prefix_lm=False, qk_ln=False, attn_uses_sequence_id=False, alibi=True, alibi_bias_max=8, **kwargs)[source]#

Bases: EasyDeLBaseConfig

This is the configuration class to store the attention-related configuration of a [MptModel].

Parameters
  • attn_type (str, optional, defaults to “multihead_attention”) – The type of attention to use. Can be either “multihead_attention” or “multiquery_attention”.

  • attn_pdrop (float, optional, defaults to 0.0) – The dropout probability applied to the attention output.

  • attn_impl (str, optional, defaults to “torch”) – The implementation of the attention mechanism. Can be either “torch” or “flash”.

  • clip_qkv (float, optional) – The clip value applied to the query, key, and value tensors.

  • softmax_scale (float, optional) – The scale factor applied to the softmax function in the attention layer.

  • prefix_lm (bool, optional, defaults to False) – Whether to use a prefix LM.

  • qk_ln (bool, optional, defaults to False) – Whether to apply layer normalization to the query and key tensors.

  • attn_uses_sequence_id (bool, optional, defaults to False) – Whether the attention layer uses sequence IDs.

  • alibi (bool, optional, defaults to True) – Whether to use the ALiBi (Attention with Linear Biases) method.

  • alibi_bias_max (int, optional, defaults to 8) – The maximum value for the ALiBi bias.
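ALiBi replaces position embeddings with a per-head linear penalty on query/key distance. The following is a minimal sketch of how alibi_bias_max shapes the per-head slopes. To the best of our understanding, it follows the slope scheme of the reference MPT implementation. alibi_slopes and alibi_bias are hypothetical helpers, and the causal mask is assumed to be applied separately.

```python
import math
from typing import List


def alibi_slopes(n_heads: int, alibi_bias_max: int = 8) -> List[float]:
    """Per-head slopes: 1 / 2**(m * alibi_bias_max / n) for m = 1..n (n padded to a power of two)."""
    padded = 2 ** math.ceil(math.log2(n_heads))
    slopes = [1.0 / 2 ** (m * alibi_bias_max / padded) for m in range(1, padded + 1)]
    if padded != n_heads:
        # For non-power-of-two head counts, interleave odd/even entries and truncate.
        slopes = (slopes[1::2] + slopes[::2])[:n_heads]
    return slopes


def alibi_bias(n_heads: int, seq_len: int, alibi_bias_max: int = 8) -> List[List[List[float]]]:
    """bias[h][i][j] = -slope_h * |i - j|, added to the attention logits before softmax."""
    slopes = alibi_slopes(n_heads, alibi_bias_max)
    return [[[-s * abs(i - j) for j in range(seq_len)] for i in range(seq_len)] for s in slopes]


print(alibi_slopes(8))  # 1/2, 1/4, ..., 1/256 with the default alibi_bias_max=8
```

The reference MPT code builds the bias from key positions only and relies on softmax being invariant to a per-row shift. Once the causal mask is applied, that form and the symmetric form used here produce the same attention weights.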

classmethod from_pretrained(pretrained_model_name_or_path, **kwargs) EasyDeLBaseConfig[source]#

Instantiate a [PretrainedConfig] (or a derived class) from a pretrained model configuration.

Parameters
  • pretrained_model_name_or_path (str or os.PathLike) –

    This can be either:

    • a string, the model id of a pretrained model configuration hosted inside a model repo on huggingface.co.

    • a path to a directory containing a configuration file saved using the [~PretrainedConfig.save_pretrained] method, e.g., ./my_model_directory/.

    • a path or url to a saved configuration JSON file, e.g., ./my_model_directory/configuration.json.

  • cache_dir (str or os.PathLike, optional) – Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.

  • force_download (bool, optional, defaults to False) – Whether or not to force to (re-)download the configuration files and override the cached versions if they exist.

  • resume_download – Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.

  • proxies (Dict[str, str], optional) – A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.

  • token (str or bool, optional) – The token to use as HTTP bearer authorization for remote files. If True, or not specified, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).

  • revision (str, optional, defaults to “main”) –

    The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.

    <Tip>

    To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.

    </Tip>

  • return_unused_kwargs (bool, optional, defaults to False) –

    If False, then this function returns just the final configuration object.

    If True, then this function returns a tuple (config, unused_kwargs) where unused_kwargs is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the part of kwargs which has not been used to update config and is otherwise ignored.

  • subfolder (str, optional, defaults to “”) – In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here.

  • kwargs (Dict[str, tp.Any], optional) – The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not configuration attributes is controlled by the return_unused_kwargs keyword parameter.

Returns

The configuration object instantiated from this pretrained model.

Return type

[PretrainedConfig]

Examples:

>>> from transformers import BertConfig
>>> # We can't directly instantiate the base class *PretrainedConfig*, so let's show the examples on a
>>> # derived class: BertConfig
>>> config = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased"
... )  # Download configuration from huggingface.co and cache.
>>> config = BertConfig.from_pretrained(
...     "./test/saved_model/"
... )  # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
>>> config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
>>> config = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased", output_attentions=True, foo=False
... )
>>> assert config.output_attentions == True
>>> config, unused_kwargs = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased",
...     output_attentions=True,
...     foo=False,
...     return_unused_kwargs=True,
... )
>>> assert config.output_attentions == True
>>> assert unused_kwargs == {"foo": False}


class easydel.__init__.MptConfig(d_model: int = 2048, n_heads: int = 16, n_layers: int = 24, expansion_ratio: int = 4, max_seq_len: int = 2048, vocab_size: int = 50368, resid_prob_drop: float = 0.0, layer_norm_epsilon: float = 1e-05, emb_prob_drop: float = 0.0, learned_pos_emb: bool = True, attn_config: Optional[MptAttentionConfig] = None, init_device: str = 'cpu', logit_scale: Optional[Union[float, str]] = None, no_bias: bool = True, verbose: int = 0, embedding_fraction: float = 1.0, norm_type: str = 'low_precision_layernorm', use_cache: bool = False, initializer_range=0.02, alibi: bool = True, use_bias: bool = False, act_fn: str = 'gelu', qk_ln: bool = False, use_lm_head: bool = False, use_norm_bias: bool = False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • d_model (int, optional, defaults to 2048) – Dimensionality of the encoder layers and the pooler layer.

  • n_heads (int, optional, defaults to 16) – Number of attention heads for each attention layer in the Transformer encoder.

  • n_layers (int, optional, defaults to 24) – Number of hidden layers in the Transformer encoder.

  • expansion_ratio (int, optional, defaults to 4) – Expansion ratio of the feed-forward layer.

  • max_seq_len (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • vocab_size (int, optional, defaults to 50368) – Vocabulary size of the MPT model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • resid_prob_drop (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

  • layer_norm_epsilon (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.

  • emb_prob_drop (float, optional, defaults to 0.0) – The dropout ratio for the embeddings.

  • learned_pos_emb (bool, optional, defaults to True) – Whether to learn positional embeddings.

  • attn_config ([MptAttentionConfig], optional) – The configuration of the attention layer.

  • init_device (str, optional, defaults to “cpu”) – The device to initialize the model on.

  • logit_scale (float or str, optional) – The logit scale. If set to “inv_sqrt_d_model”, the logit scale is calculated as 1 / math.sqrt(d_model).

  • no_bias (bool, optional, defaults to True) – Whether to omit bias in the linear layers.

  • verbose (int, optional, defaults to 0) – The verbosity level.

  • embedding_fraction (float, optional, defaults to 1.0) – The fraction of the embedding matrix to use.

  • norm_type (str, optional, defaults to “low_precision_layernorm”) – The type of layer normalization to use.

  • use_cache (bool, optional, defaults to False) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • alibi (bool, optional, defaults to True) – Whether to use ALiBi (Attention with Linear Biases) method.

  • use_bias (bool, optional, defaults to False) – Whether to use bias in the linear layers.

  • act_fn (str, optional, defaults to “gelu”) – The activation function to use.

  • qk_ln (bool, optional, defaults to False) – Whether to apply layer normalization to the query and key tensors.

  • use_lm_head (bool, optional, defaults to False) – Whether to use a language modeling head.

  • use_norm_bias (bool, optional, defaults to False) – Whether to use bias in the layer normalization layers.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • bits (int, optional) – The number of bits to quantize the model to.
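Several MPT sizes follow from the fields above. The following is a minimal sketch under the usual MPT conventions. It assumes the feed-forward width is expansion_ratio * d_model and the per-head width is d_model // n_heads. mpt_derived_sizes is a hypothetical helper. Only the "inv_sqrt_d_model" string is handled, as documented for logit_scale.

```python
import math
from typing import Dict, Optional, Union


def mpt_derived_sizes(
    d_model: int = 2048,
    n_heads: int = 16,
    expansion_ratio: int = 4,
    logit_scale: Optional[Union[float, str]] = None,
) -> Dict[str, object]:
    """Derive head width, FFN width, and the effective logit scale from MptConfig-style fields."""
    if d_model % n_heads != 0:
        raise ValueError("d_model must be divisible by n_heads")
    if logit_scale == "inv_sqrt_d_model":
        scale: Optional[float] = 1.0 / math.sqrt(d_model)  # documented special value
    elif isinstance(logit_scale, str):
        raise ValueError(f"unsupported logit_scale string: {logit_scale!r}")
    else:
        scale = logit_scale  # a float, or None for no scaling
    return {
        "head_dim": d_model // n_heads,
        "ffn_hidden_size": expansion_ratio * d_model,
        "logit_scale": scale,
    }


print(mpt_derived_sizes())  # defaults: head_dim 128, ffn_hidden_size 8192, logit_scale None
```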

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#
attribute_map: Dict[str, str] = {'hidden_size': 'd_model', 'max_position_embeddings': 'max_seq_len', 'num_attention_heads': 'n_heads', 'num_hidden_layers': 'n_layers', 'tie_word_embeddings': 'use_lm_head'}#
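attribute_map lets code written against the generic names (hidden_size, num_hidden_layers, and so on) read and write MPT's own field names. The following is a minimal, self-contained sketch of that aliasing, modeled on the attribute_map convention of Hugging Face's PretrainedConfig. AliasedConfig is hypothetical and is not EasyDeL's implementation.

```python
from typing import Any, Dict


class AliasedConfig:
    """Minimal sketch of attribute_map-style aliasing for a config object."""

    attribute_map: Dict[str, str] = {
        "hidden_size": "d_model",
        "max_position_embeddings": "max_seq_len",
        "num_attention_heads": "n_heads",
        "num_hidden_layers": "n_layers",
        "tie_word_embeddings": "use_lm_head",
    }

    def __init__(self, **kwargs: Any) -> None:
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __setattr__(self, name: str, value: Any) -> None:
        # Writes through a generic name land on the model-specific attribute.
        super().__setattr__(self.attribute_map.get(name, name), value)

    def __getattr__(self, name: str) -> Any:
        # Only called when normal lookup fails, i.e. for the generic alias names.
        mapped = type(self).attribute_map.get(name)
        if mapped is None:
            raise AttributeError(name)
        return getattr(self, mapped)


cfg = AliasedConfig(d_model=2048, n_heads=16, n_layers=24)
print(cfg.hidden_size, cfg.num_attention_heads)  # read through the generic names
```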
get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
model_type: str = 'mpt'#
class easydel.__init__.MptForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.MptModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

property alibi#
class easydel.__init__.OPTConfig(vocab_size: int = 50272, hidden_size: int = 768, num_hidden_layers: int = 12, ffn_dim: int = 3072, max_position_embeddings: int = 2048, do_layer_norm_before: bool = True, _remove_final_layer_norm: bool = False, word_embed_proj_dim: int = None, dropout: float = 0.1, attention_dropout: float = 0.0, num_attention_heads: int = 12, activation_function: str = 'relu', layerdrop: float = 0.0, init_std: float = 0.02, use_cache: bool = True, pad_token_id: int = 1, bos_token_id: int = 2, eos_token_id: int = 2, enable_bias: bool = True, layer_norm_elementwise_affine: bool = True, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 50272) – Vocabulary size of the OPT model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 768) – Dimensionality of the encoder layers and the pooler layer.

  • num_hidden_layers (int, optional, defaults to 12) – Number of hidden layers in the Transformer encoder.

  • ffn_dim (int, optional, defaults to 3072) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048).

  • do_layer_norm_before (bool, optional, defaults to True) – Whether to perform layer normalization before the attention block.

  • _remove_final_layer_norm (bool, optional, defaults to False) – Whether to remove the final layer norm.

  • word_embed_proj_dim (int, optional) – The dimension of the word embedding projection. If not provided, it will default to hidden_size.

  • dropout (float, optional, defaults to 0.1) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • num_attention_heads (int, optional, defaults to 12) – Number of attention heads for each attention layer in the Transformer encoder.

  • activation_function (str or function, optional, defaults to “relu”) – The non-linear activation function (function or string) to use in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • layerdrop (float, optional, defaults to 0.0) – The LayerDrop probability for the encoder. See the LayerDrop paper (https://arxiv.org/abs/1909.11556) for more details.

  • init_std (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional, defaults to 1) – The index of the padding token in the vocabulary.

  • bos_token_id (int, optional, defaults to 2) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.

  • enable_bias (bool, optional, defaults to True) – Whether to use bias in the linear layers.

  • layer_norm_elementwise_affine (bool, optional, defaults to True) – Whether to use elementwise affine in the layer normalization layers.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.
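LayerDrop skips whole layers at random during training. The following is a minimal sketch. It assumes each layer is skipped independently with probability layerdrop in training mode and never at inference. run_layers_with_layerdrop is hypothetical and uses plain lists in place of arrays.

```python
import random
from typing import Callable, List, Optional, Sequence

Layer = Callable[[List[float]], List[float]]


def run_layers_with_layerdrop(
    hidden: List[float],
    layers: Sequence[Layer],
    layerdrop: float = 0.0,
    training: bool = True,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Apply each layer in turn, skipping it with probability `layerdrop` while training."""
    rng = rng or random.Random()
    for layer in layers:
        if training and rng.random() < layerdrop:
            continue  # dropped: the hidden state passes through unchanged
        hidden = layer(hidden)
    return hidden
```

With the default layerdrop=0.0, every layer always runs, so the behavior is identical to a plain forward pass.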

attach_custom_arguments(vocab_size: int = 50272, hidden_size: int = 768, num_hidden_layers: int = 12, ffn_dim: int = 3072, max_position_embeddings: int = 2048, do_layer_norm_before: bool = True, _remove_final_layer_norm: bool = False, word_embed_proj_dim: int = None, dropout: float = 0.1, attention_dropout: float = 0.0, num_attention_heads: int = 12, activation_function: str = 'relu', layerdrop: float = 0.0, init_std: float = 0.02, use_cache: bool = True, pad_token_id: int = 1, bos_token_id: int = 2, eos_token_id: int = 2, enable_bias: bool = True, layer_norm_elementwise_affine: bool = True, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#
get_partition_rules(fully_sharded_data_parallel: bool = True)[source]#

Get the partition rules for the model.

Parameters

fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

keys_to_ignore_at_inference = ['past_key_values']#
model_type: str = 'opt'#
class easydel.__init__.OPTForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

get_decoder()[source]#
get_input_embeddings()[source]#
get_output_embeddings()[source]#
prepare_inputs_for_generation(input_ids, max_length, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

The prepare_inputs_for_generation function is used to prepare the inputs for a generation task.

Parameters
  • self – The model instance.

  • input_ids – The input token ids.

  • max_length – The maximum length of the sequence to be generated.

  • attention_mask – tp.Optional[chex.Array]: Mask applied to the attention weights.

Returns

A dictionary containing past_key_values, attention_mask, and position_ids.
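The following is a minimal sketch of the attention-mask and position-id bookkeeping described above. It assumes the convention used by Hugging Face's Flax models: the mask is padded with ones up to max_length, and position ids come from a cumulative sum of the mask. The key/value cache (past_key_values) is left out because it depends on the model. Both helpers are hypothetical. update_generation_inputs shows the per-step position update one would expect from update_inputs_for_generation.

```python
from typing import Dict, List, Optional


def prepare_generation_inputs(
    input_ids: List[List[int]],
    max_length: int,
    attention_mask: Optional[List[List[int]]] = None,
) -> Dict[str, List[List[int]]]:
    batch_size, seq_len = len(input_ids), len(input_ids[0])
    if attention_mask is None:
        attention_mask = [[1] * seq_len for _ in range(batch_size)]
    extended_mask, position_ids = [], []
    for row in attention_mask:
        # Pad with ones up to max_length so tokens generated later are attendable.
        extended_mask.append(list(row) + [1] * (max_length - len(row)))
        # Positions count only real (mask == 1) tokens: cumulative sum of the mask minus one.
        positions, running = [], 0
        for m in row:
            running += m
            positions.append(running - 1)
        position_ids.append(positions)
    return {"attention_mask": extended_mask, "position_ids": position_ids}


def update_generation_inputs(position_ids: List[List[int]]) -> List[List[int]]:
    """After each decoding step only the newest token is fed, one position past the last."""
    return [[row[-1] + 1] for row in position_ids]
```

Left-padding positions receive -1 here. They are masked out, so the value does not affect attention.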

set_decoder(decoder)[source]#
set_input_embeddings(value)[source]#
set_output_embeddings(new_embeddings)[source]#
update_inputs_for_generation(model_outputs, model_kwargs)[source]#
class easydel.__init__.OPTModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

get_input_embeddings()[source]#
set_input_embeddings(value)[source]#
class easydel.__init__.ORPOConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: ~typing.Optional[str] = None, clip_grad: ~typing.Optional[float] = None, custom_scheduler: ~typing.Optional[~typing.Callable[[int], ~typing.Any]] = None, dataloader_num_workers: ~typing.Optional[int] = 0, dataloader_pin_memory: ~typing.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: ~typing.Optional[int] = None, evaluation_steps: ~typing.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: ~typing.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: ~typing.Optional[~typing.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: ~typing.Optional[dict] = None, learning_rate: float = 5e-05, learning_rate_end: ~typing.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: ~typing.Optional[~easydel.infra.loss_utils.LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: ~typing.Optional[int] = None, max_sequence_length: ~typing.Optional[int] = 4096, max_training_steps: ~typing.Optional[int] = None, model_name: str = 'BaseTrainer', model_parameters: ~typing.Optional[dict] = None, metrics_to_show_in_rich_pbar: ~typing.Optional[~typing.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: ~typing.Literal['adafactor', 'lion', 'adamw', 'rmsprop'] = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: ~typing.Any = None, process_zero_is_admin: bool = True, progress_bar_type: ~typing.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: ~typing.Optional[int] = None, save_total_limit: ~typing.Optional[int] = None, scheduler: ~typing.Literal['linear', 'cosine', 'none'] = EasyDeLSchedulers.NONE, sparsify_module: bool = False, sparse_module_type: ~typing.Literal['bcoo', 'bcsr', 'coo', 'csr'] = 'bcoo', state_apply_fn_kwarguments_to_model: ~typing.Optional[dict] = None, step_partition_spec: ~jax._src.partition_spec.PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: ~typing.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: ~typing.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: ~typing.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: ~typing.Optional[~numpy.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: ~typing.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0)[source]#

Bases: TrainingArguments

Configuration class for ORPO training settings.

This class inherits from TrainingArguments and holds configuration parameters specific to ORPO training. The dataclass automatically generates an initializer, and the __post_init__ method further processes some of the parameters after object initialization.

model_name#

The name of the model. Default is “ORPOTrainer”.

Type

str

learning_rate#

The learning rate used during training. Default is 1e-6.

Type

float

max_length#

The maximum allowed sequence length for the input. Default is 1024.

Type

Optional[int]

max_prompt_length#

The maximum allowed length of the prompt portion of the input. Default is 512.

Type

Optional[int]

max_completion_length#

The maximum allowed length of the completion. If not provided, it is set to max_length - max_prompt_length.

Type

Optional[int]
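The following is a minimal sketch of the documented fallback for max_completion_length. It is written as a hypothetical dataclass holding only the three length fields; the real ORPOConfig does more in __post_init__.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class LengthSettings:
    """Hypothetical stand-in for the length fields of ORPOConfig."""

    max_length: Optional[int] = 1024
    max_prompt_length: Optional[int] = 512
    max_completion_length: Optional[int] = None

    def __post_init__(self) -> None:
        # Documented fallback: completion budget = total budget - prompt budget.
        if (
            self.max_completion_length is None
            and self.max_length is not None
            and self.max_prompt_length is not None
        ):
            self.max_completion_length = self.max_length - self.max_prompt_length


print(LengthSettings().max_completion_length)  # 1024 - 512 with the documented defaults
```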

beta#

A hyperparameter beta, with a default value of 0.1.

Type

float
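The following is a minimal per-pair sketch of the ORPO objective. It assumes EasyDeL follows the ORPO paper as implemented in Hugging Face TRL. Under that assumption, beta weights an odds-ratio term that is added to the ordinary SFT loss on the chosen response. Log-probabilities are length-averaged, and the log-odds are computed as log p - log(1 - p). orpo_loss and its helpers are hypothetical.

```python
import math


def log1mexp(x: float) -> float:
    """log(1 - exp(x)), defined for x < 0."""
    if x >= 0.0:
        raise ValueError("log-probability must be negative")
    return math.log1p(-math.exp(x))


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))


def orpo_loss(chosen_nll: float, chosen_logp: float, rejected_logp: float, beta: float = 0.1) -> float:
    """SFT loss on the chosen answer plus beta times the odds-ratio penalty.

    chosen_logp / rejected_logp: length-averaged log-probabilities (must be < 0).
    chosen_nll: the usual negative log-likelihood on the chosen response.
    """
    # log(odds_chosen / odds_rejected), where odds(p) = p / (1 - p)
    log_odds_ratio = (chosen_logp - rejected_logp) - (log1mexp(chosen_logp) - log1mexp(rejected_logp))
    return chosen_nll - beta * log_sigmoid(log_odds_ratio)
```

When the chosen and rejected answers are equally likely, the penalty is beta * log(2). It shrinks toward zero as the chosen answer's odds grow relative to the rejected one.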

disable_dropout#

Flag to disable dropout during training. Default is True.

Type

bool

label_pad_token_id#

The token id used for padding labels. Default is -100.

Type

int

padding_value#

The value used for padding sequences. Default is None.

Type

Optional[int]

generate_during_eval#

Flag indicating whether to generate sequences during evaluation. Default is False.

Type

bool

is_encoder_decoder#

Flag to indicate if the model is encoder-decoder. Default is None.

Type

Optional[bool]

model_init_kwargs#

Additional keyword arguments for model initialization. Default is None.

Type

Optional[Dict[str, Any]]

dataset_num_proc#

Number of processes to use for dataset processing. Default is None.

Type

Optional[int]

max_sequence_length#

Computed attribute representing the maximum sequence length used for training. It is set in the __post_init__ method.

Type

int

beta: float = 0.1#
dataset_num_proc: Optional[int] = None#
disable_dropout: bool = True#
generate_during_eval: bool = False#
is_encoder_decoder: Optional[bool] = None#
label_pad_token_id: int = -100#
learning_rate: float = 1e-06#
max_completion_length: Optional[int] = None#
max_length: Optional[int] = 1024#
max_prompt_length: Optional[int] = 512#
model_name: str = 'ORPOTrainer'#
padding_value: Optional[int] = None#
replace(**kwargs)#
class easydel.__init__.ORPOTrainer(arguments: ORPOConfig, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, data_collator: Optional[DPODataCollatorWithPadding] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, processing_class: Optional[Any] = None)[source]#

Bases: Trainer

arguments: ORPOConfig#
build_tokenized_answer(prompt: str, answer: str) Dict[str, ndarray][source]#

Tokenizes a prompt and answer pair, handling special tokens and padding/truncation.

Parameters
  • prompt (str) – The prompt text.

  • answer (str) – The answer text.

Returns

A dictionary containing the tokenized prompt and answer, along with attention masks.

Return type

tp.Dict[str, np.ndarray]

Raises

ValueError – If there’s a mismatch in token lengths.
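A minimal sketch of how a prompt/answer pair can be split into prompt and answer token ids, assuming the approach used by Hugging Face TRL (encode prompt + answer jointly, then locate the prompt boundary). The names build_tokenized_answer_sketch and toy_tokenize, the whitespace tokenizer, and the returned key names are hypothetical stand-ins, not EasyDeL's processing_class or its actual return keys:

```python
from typing import Callable, Dict, List

import numpy as np

VOCAB: Dict[str, int] = {}


def toy_tokenize(text: str) -> List[int]:
    # Whitespace "tokenizer" that assigns a new id to each unseen word (illustration only).
    return [VOCAB.setdefault(word, len(VOCAB)) for word in text.split()]


def build_tokenized_answer_sketch(
    prompt: str, answer: str, tokenize: Callable[[str], List[int]]
) -> Dict[str, np.ndarray]:
    full_ids = tokenize(prompt + answer)
    prompt_ids = tokenize(prompt)
    if len(prompt_ids) > len(full_ids):
        raise ValueError("The prompt encodes to more tokens than prompt + answer.")
    # Tokens can merge across the prompt/answer boundary, so the split point is
    # taken from the joint encoding and moved back one token if the prefixes disagree.
    start = len(prompt_ids)
    if full_ids[:start] != prompt_ids:
        start -= 1
    prompt_part, answer_part = full_ids[:start], full_ids[start:]
    return {
        "prompt_input_ids": np.asarray(prompt_part, dtype=np.int32),
        "prompt_attention_mask": np.ones(len(prompt_part), dtype=np.int32),
        "input_ids": np.asarray(answer_part, dtype=np.int32),
        "attention_mask": np.ones(len(answer_part), dtype=np.int32),
    }
```

The second check in the sketch covers the "Hel" + "lo" case, where the joint encoding produces a single token and the whole sequence is attributed to the answer.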

configure_functions() TrainerConfigureFunctionOutput[source]#

Configures and JIT-compiles the training and evaluation step functions.

This method sets up the necessary functions for training and evaluation, including:
  • Initialization of the model state.

  • Sharding of the model parameters and optimizer state.

  • JIT-compilation of the training and evaluation step functions.

Returns

An object containing the configured functions and other relevant information.

Return type

TrainerConfigureFunctionOutput

create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') Callable[source]#

Creates a data collection function for batching.

For ORPO training, this method simply returns the pre-configured data_collator.

Parameters
  • max_sequence_length (int) – The maximum sequence length (not used in this implementation).

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode (not used in this implementation). Defaults to “keep_end”.

Returns

The data collator function.

Return type

tp.Callable

tokenize_row(feature: Dict[str, str], state: Optional[object] = None) Dict[str, ndarray][source]#

Tokenizes a single row of data from the ORPO dataset.

This method tokenizes the prompt, chosen response, and rejected response, handles padding and truncation, and prepares the data for input to the ORPO model.

Parameters
  • feature (tp.Dict) – A dictionary containing the “prompt”, “chosen”, and “rejected” texts.

  • state (EasyDeLState, optional) – Not used in this implementation. Defaults to None.

Returns

A dictionary containing the tokenized prompt, chosen response, and rejected response, along with attention masks and labels.

Return type

tp.Dict

Raises

ValueError – If the input data types are incorrect.
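The labels built for each chosen/rejected sequence typically mask out the prompt so that only response tokens contribute to the loss. A sketch under stated assumptions: it starts from already-tokenized ids, uses -100 as the ignored-label value, appends EOS, and right-truncates to max_length. The helpers build_sequence and tokenize_row_sketch and the key names (chosen_input_ids, chosen_labels, ...) are hypothetical, and the real trainer may truncate differently:

```python
from typing import Dict, List

import numpy as np

LABEL_PAD_TOKEN_ID = -100  # assumed "ignore this position" label value


def build_sequence(
    prompt_ids: List[int], answer_ids: List[int], eos_token_id: int, max_length: int
) -> Dict[str, np.ndarray]:
    # Concatenate prompt and answer, terminate with EOS, then right-truncate.
    input_ids = (prompt_ids + answer_ids + [eos_token_id])[:max_length]
    labels = list(input_ids)
    # Mask the prompt positions so only response tokens are scored.
    n_prompt = min(len(prompt_ids), len(labels))
    labels[:n_prompt] = [LABEL_PAD_TOKEN_ID] * n_prompt
    return {
        "input_ids": np.asarray(input_ids, dtype=np.int32),
        "attention_mask": np.ones(len(input_ids), dtype=np.int32),
        "labels": np.asarray(labels, dtype=np.int32),
    }


def tokenize_row_sketch(prompt_ids, chosen_ids, rejected_ids, eos_token_id, max_length):
    row: Dict[str, np.ndarray] = {}
    for prefix, answer_ids in (("chosen", chosen_ids), ("rejected", rejected_ids)):
        for key, value in build_sequence(prompt_ids, answer_ids, eos_token_id, max_length).items():
            row[f"{prefix}_{key}"] = value
    row["prompt_input_ids"] = np.asarray(prompt_ids, dtype=np.int32)
    return row
```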

class easydel.__init__.Olmo2Config(vocab_size=50304, hidden_size=4096, intermediate_size=11008, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=None, hidden_act='silu', max_position_embeddings=2048, initializer_range=0.02, use_cache=True, pad_token_id=1, bos_token_id=None, eos_token_id=50279, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, attention_bias=False, attention_dropout=0.0, rms_norm_eps=1e-05, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

This is the configuration class to store the configuration of an [Olmo2Model]. It is used to instantiate an OLMo2 model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the [allenai/Olmo2-7B-1124-hf](https://huggingface.co/allenai/Olmo2-7B-1124-hf).

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 50304) – Vocabulary size of the Olmo2 model. Defines the number of different tokens that can be represented by the input_ids passed when calling [Olmo2Model].

  • hidden_size (int, optional, defaults to 4096) – Dimension of the hidden representations.

  • intermediate_size (int, optional, defaults to 11008) – Dimension of the MLP representations.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer decoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer decoder.

  • num_key_value_heads (int, optional) – The number of key/value heads used to implement Grouped Query Attention. If num_key_value_heads=num_attention_heads, the model uses Multi Head Attention (MHA); if num_key_value_heads=1, it uses Multi Query Attention (MQA); otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by mean-pooling all the original heads within that group. For more details, check out [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it defaults to num_attention_heads.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) in the decoder.

  • max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/value states (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional, defaults to 1) – Padding token id.

  • bos_token_id (int, optional) – Beginning of stream token id.

  • eos_token_id (int, optional, defaults to 50279) – End of stream token id.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The base period of the RoPE embeddings.

  • rope_scaling (tp.Dict, optional) – Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is {“type”: strategy name, “factor”: scaling factor}. When using this flag, don’t update max_position_embeddings to the expected new maximum. See the following thread for more information on how these scaling strategies behave: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an experimental feature, subject to breaking API changes in future versions.

  • attention_bias (bool, optional, defaults to False) – Whether to use a bias in the query, key, value and output projection layers during self-attention.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • rms_norm_eps (float, optional, defaults to 1e-05) – The epsilon used by the rms normalization layers.
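The mean-pooling conversion described under num_key_value_heads can be sketched with NumPy. The (num_heads, head_dim, hidden_size) weight layout and the function name mha_to_gqa_kv are assumptions for illustration; consecutive heads form a group, so query head i reads key/value head i // group_size:

```python
import numpy as np


def mha_to_gqa_kv(kv_weight: np.ndarray, num_key_value_heads: int) -> np.ndarray:
    """Mean-pool per-head key (or value) projection weights into GQA groups.

    kv_weight is assumed to have shape (num_attention_heads, head_dim, hidden_size).
    """
    num_heads = kv_weight.shape[0]
    if num_heads % num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    group_size = num_heads // num_key_value_heads
    # Gather consecutive heads into groups, then average within each group.
    grouped = kv_weight.reshape(num_key_value_heads, group_size, *kv_weight.shape[1:])
    return grouped.mean(axis=1)
```

Passing num_key_value_heads equal to the number of heads leaves the weights unchanged (MHA), and passing 1 averages all heads into a single shared head (MQA).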

>>> # Note: this example uses the Hugging Face transformers classes.
>>> from transformers import Olmo2Model, Olmo2Config
>>> # Initializing an OLMo2 7B style configuration
>>> configuration = Olmo2Config()
>>> # Initializing a model from the OLMo2 7B style configuration
>>> model = Olmo2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
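To make the linear and dynamic rope_scaling strategies concrete, here is a NumPy sketch of the rotary angles under each. It assumes the formulas used in Hugging Face transformers' Llama-style rotary embeddings: linear divides positions by the factor, and dynamic NTK enlarges the base only once the sequence exceeds max_position_embeddings. The helper names are hypothetical, and EasyDeL's own implementation may differ in detail:

```python
from typing import Dict, Optional, Union

import numpy as np


def rope_inv_freq(dim: int, base: float) -> np.ndarray:
    # Standard RoPE inverse frequencies base^(-2i/dim), i = 0 .. dim/2 - 1.
    return 1.0 / (base ** (np.arange(0, dim, 2, dtype=np.float64) / dim))


def rope_angles(
    seq_len: int,
    dim: int,
    rope_theta: float = 10000.0,
    max_position_embeddings: int = 2048,
    rope_scaling: Optional[Dict[str, Union[str, float]]] = None,
) -> np.ndarray:
    positions = np.arange(seq_len, dtype=np.float64)
    base = rope_theta
    if rope_scaling is not None:
        kind, factor = rope_scaling["type"], float(rope_scaling["factor"])
        if kind == "linear":
            # Linear (position interpolation): squeeze positions by the factor.
            positions = positions / factor
        elif kind == "dynamic":
            # Dynamic NTK: enlarge the base only once the sequence outgrows the trained length.
            if seq_len > max_position_embeddings:
                base = rope_theta * (
                    (factor * seq_len / max_position_embeddings) - (factor - 1)
                ) ** (dim / (dim - 2))
        else:
            raise ValueError(f"unsupported rope_scaling type: {kind!r}")
    return np.outer(positions, rope_inv_freq(dim, base))  # shape (seq_len, dim // 2)
```

In both cases max_position_embeddings keeps its original value, which matches the parameter description's advice not to raise it to the new maximum.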
attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None)[source]#
get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]
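Partition rules are pairs of a parameter-name regular expression and a PartitionSpec. A pure-Python sketch of how such rules are commonly resolved, assuming first-match semantics with re.search over a slash-separated parameter path. The rule patterns, paths, and the match_partition_rule helper are hypothetical, and plain tuples stand in for jax.sharding.PartitionSpec so the snippet runs without JAX:

```python
import re
from typing import Optional, Sequence, Tuple, Union

# Stand-in for jax.sharding.PartitionSpec: one entry per array dimension.
Spec = Tuple[Optional[Union[str, Tuple[str, ...]]], ...]

RULES: Tuple[Tuple[str, Spec], ...] = (
    (r"embed_tokens/embedding", ("tp", ("fsdp", "sp"))),
    (r"self_attn/(q_proj|k_proj|v_proj)/kernel", (("fsdp", "sp"), "tp")),
    (r".*", ()),  # catch-all: an empty spec means fully replicated
)


def match_partition_rule(param_path: str, rules: Sequence[Tuple[str, Spec]]) -> Spec:
    # Rules are tried in order; the first pattern found in the path wins.
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")
```

Ordering matters under this scheme: specific patterns go first and a catch-all last, so every parameter receives some spec.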

property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
keys_to_ignore_at_inference = ['past_key_values']#
model_type: str = 'olmo2'#
class easydel.__init__.Olmo2ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.Olmo2ForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.Olmo2Model(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.OlmoConfig(vocab_size=50304, hidden_size=4096, intermediate_size=11008, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=None, hidden_act='silu', max_position_embeddings=2048, initializer_range=0.02, use_cache=True, pad_token_id=1, bos_token_id=None, eos_token_id=50279, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, attention_bias=False, attention_dropout=0.0, clip_qkv=None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 50304) – Vocabulary size of the Olmo model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4096) – Dimension of the hidden representations.

  • intermediate_size (int, optional, defaults to 11008) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer decoder.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer decoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer decoder.

  • num_key_value_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer decoder. Will default to num_attention_heads if not set.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the decoder. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/value states (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional, defaults to 1) – The index of the padding token in the vocabulary.

  • bos_token_id (int, optional) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 50279) – The id of the end-of-sequence token.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.

  • attention_bias (bool, optional, defaults to False) – Whether to use attention bias.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • clip_qkv (float, optional) – The clip value applied to the query, key, and value tensors.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing policy.

  • use_scan_mlp (bool, optional, defaults to False) – Whether to use the scan implementation for the MLP.

  • scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size to use when scanning the MLP.

  • bits (int, optional) – The number of bits to quantize the model to.
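A minimal sketch of clip_qkv, assuming it behaves like the Hugging Face OLMo implementation: the query, key and value projections are clamped symmetrically to [-clip_qkv, clip_qkv], and no clipping happens when the value is None. The function name clip_qkv_projections is hypothetical:

```python
from typing import Optional, Tuple

import numpy as np


def clip_qkv_projections(
    q: np.ndarray, k: np.ndarray, v: np.ndarray, clip_qkv: Optional[float] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    if clip_qkv is None:
        return q, k, v  # clipping disabled (the default)
    # Symmetric element-wise clamp applied to all three projections.
    return tuple(np.clip(x, -clip_qkv, clip_qkv) for x in (q, k, v))
```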

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None)[source]#
get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
model_type: str = 'olmo'#
class easydel.__init__.OlmoForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.OlmoModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.OpenELMConfig(vocab_size: int = 32000, max_context_length: int = 2048, num_transformer_layers: int = 12, model_dim: int = 2048, head_dim: int = 128, qkv_multipliers: Union[Number, List[Number]] = 1.0, num_query_heads: Optional[int] = None, num_gqa_groups: int = 1, ffn_multipliers: Union[Number, List[Number]] = 4.0, ffn_with_glu: bool = True, ffn_dim_divisor: int = 256, activation_fn_name: str = 'swish', normalization_layer_name: str = 'rms_norm', normalize_qk_projections: bool = False, share_input_output_layers: bool = False, rope_freq_constant: int = 10000, rope_max_length: int = 4096, initializer_range: float = 0.02, use_cache: bool = True, bos_token_id: int = 1, eos_token_id: int = 2, rope_scaling: Dict[str, Union[str, float]] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 32000) – Vocabulary size of the OpenELM model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • max_context_length (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • num_transformer_layers (int, optional, defaults to 12) – Number of hidden layers in the Transformer decoder.

  • model_dim (int, optional, defaults to 2048) – Dimension of the hidden representations.

  • head_dim (int, optional, defaults to 128) – Dimensionality of the attention heads.

  • qkv_multipliers (float or list of float, optional, defaults to 1.0) – The multiplier for the query, key, and value projections.

  • num_query_heads (int, optional) – Number of query heads. If not provided, it will be calculated based on model_dim and head_dim.

  • num_gqa_groups (int, optional, defaults to 1) – Number of GQA (Grouped Query Attention) groups.

  • ffn_multipliers (float or list of float, optional, defaults to 4.0) – The multiplier for the feed-forward network.

  • ffn_with_glu (bool, optional, defaults to True) – Whether to use a gated linear unit (GLU) in the feed-forward network.

  • ffn_dim_divisor (int, optional, defaults to 256) – The divisor for the feed-forward network dimension: the feed-forward width is rounded to a multiple of this value.

  • activation_fn_name (str, optional, defaults to “swish”) – The activation function to use.

  • normalization_layer_name (str, optional, defaults to “rms_norm”) – The normalization layer to use.

  • normalize_qk_projections (bool, optional, defaults to False) – Whether to normalize the query and key projections.

  • share_input_output_layers (bool, optional, defaults to False) – Whether to share (tie) the weights of the input embedding layer and the output projection layer.

  • rope_freq_constant (int, optional, defaults to 10000) – The frequency constant for Rotary Position Embeddings (RoPE).

  • rope_max_length (int, optional, defaults to 4096) – The maximum length for RoPE.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/value states (not used by all models). Only relevant if config.is_decoder=True.

  • bos_token_id (int, optional, defaults to 1) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing policy.

  • use_scan_mlp (bool, optional, defaults to False) – Whether to use the scan implementation for the MLP.

  • scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size to use when scanning the MLP.

  • bits (int, optional) – The number of bits to quantize the model to.
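OpenELM allocates attention heads and feed-forward width per layer. The sketch below assumes the layer-wise scaling of Apple's reference OpenELM configuration: a scalar multiplier applies to every layer, a (first, last) pair is linearly interpolated across layers, and widths are rounded with a MobileNet-style make_divisible (the FFN width to a multiple of ffn_dim_divisor). The helper names are hypothetical, and EasyDeL's exact rounding rules are an assumption here:

```python
from numbers import Number
from typing import Dict, List, Optional, Sequence, Union

import numpy as np


def make_divisible(v: float, divisor: int = 8, min_value: Optional[int] = None) -> int:
    # MobileNet-style rounding: nearest multiple of `divisor`, never more than 10% below v.
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


def per_layer(multipliers: Union[Number, Sequence[float]], num_layers: int) -> List[float]:
    # A scalar applies to every layer; a (first, last) pair is interpolated linearly.
    if isinstance(multipliers, Number):
        return [float(multipliers)] * num_layers
    first, last = multipliers
    return [round(float(m), 2) for m in np.linspace(first, last, num=num_layers)]


def openelm_layer_dims(
    model_dim: int = 2048,
    head_dim: int = 128,
    num_transformer_layers: int = 12,
    qkv_multipliers: Union[Number, Sequence[float]] = 1.0,
    ffn_multipliers: Union[Number, Sequence[float]] = 4.0,
    num_gqa_groups: int = 1,
    ffn_dim_divisor: int = 256,
) -> List[Dict[str, int]]:
    dims = []
    qkv = per_layer(qkv_multipliers, num_transformer_layers)
    ffn = per_layer(ffn_multipliers, num_transformer_layers)
    for q_mult, f_mult in zip(qkv, ffn):
        # Query width must split evenly into heads, and heads into GQA groups.
        q_dim = make_divisible(model_dim * q_mult, divisor=head_dim * num_gqa_groups)
        num_query_heads = q_dim // head_dim
        dims.append({
            "num_query_heads": num_query_heads,
            "num_kv_heads": num_query_heads // num_gqa_groups,
            "ffn_dim": make_divisible(model_dim * f_mult, divisor=ffn_dim_divisor),
        })
    return dims
```

Under these assumptions the defaults give every layer 16 query heads and an 8192-wide feed-forward network, while qkv_multipliers=(0.5, 1.0) grows the layers from 8 to 16 query heads.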

attribute_map: Dict[str, str] = {'tie_word_embedding': 'share_input_output_layers'}#
get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()[source]#
property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
model_type: str = 'openelm'#
static rng_keys()[source]#
class easydel.__init__.OpenELMForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.OpenELMModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

property frequencies#

Returns frequency values from the config.

class easydel.__init__.PartitionAxis(batch_axis: Optional[Union[Tuple[str, ...], str]] = ('fsdp', 'dp'), sequence_axis: Optional[Union[Tuple[str, ...], str]] = 'sp', query_sequence_axis: Optional[Union[Tuple[str, ...], str]] = 'sp', head_axis: Optional[Union[Tuple[str, ...], str]] = 'tp', key_sequence_axis: Optional[Union[Tuple[str, ...], str]] = 'sp', hidden_state_axis: Optional[Union[Tuple[str, ...], str]] = 'tp', attention_dim_axis: Optional[Union[Tuple[str, ...], str]] = None, bias_head_sequence_axis: Optional[Union[Tuple[str, ...], str]] = None, bias_key_sequence_axis: Optional[Union[Tuple[str, ...], str]] = None, generation_query_sequence_axis: Optional[Union[Tuple[str, ...], str]] = None, generation_head_axis: Optional[Union[Tuple[str, ...], str]] = 'tp', generation_key_sequence_axis: Optional[Union[Tuple[str, ...], str]] = 'sp', generation_attention_dim_axis: Optional[Union[Tuple[str, ...], str]] = None)[source]#

Bases: NamedTuple

A NamedTuple representing different axes of partitioning in a model.

Each field represents an axis and its corresponding partitioning strategy. The value of each field can be:

  • None: The axis is not partitioned.

  • str: The name of the single mesh dimension across which the axis is partitioned.

  • Tuple[str, …]: A tuple of mesh dimension names, indicating a sharding strategy where the axis is split across multiple mesh dimensions.
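How the three value forms translate into a sharding spec can be shown with a (batch, sequence, hidden) activation. The sketch mirrors three PartitionAxis fields and their documented defaults in a local NamedTuple (AxesSketch, hypothetical) so it runs without EasyDeL or JAX; the resulting tuple is what one would unpack into jax.sharding.PartitionSpec(*spec):

```python
from typing import NamedTuple, Optional, Tuple, Union

AxisSpec = Optional[Union[Tuple[str, ...], str]]


class AxesSketch(NamedTuple):
    """Local mirror of three PartitionAxis fields (illustration only)."""

    batch_axis: AxisSpec = ("fsdp", "dp")  # tuple: split over several mesh axes
    sequence_axis: AxisSpec = "sp"         # str: split over a single mesh axis
    hidden_state_axis: AxisSpec = "tp"


def hidden_states_spec(axes: AxesSketch) -> Tuple[AxisSpec, AxisSpec, AxisSpec]:
    # One entry per dimension of a (batch, sequence, hidden) array;
    # a None entry leaves that dimension unpartitioned (replicated).
    return (axes.batch_axis, axes.sequence_axis, axes.hidden_state_axis)
```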

batch_axis#

Partitioning strategy for the batch dimension. Defaults to (“fsdp”, “dp”).

Type

Optional[Union[Tuple[str, …], str]]

sequence_axis#

Partitioning strategy for the sequence dimension. Defaults to “sp”.

Type

Optional[Union[Tuple[str, …], str]]

query_sequence_axis#

Partitioning strategy for the query sequence dimension. Defaults to “sp”.

Type

Optional[Union[Tuple[str, …], str]]

head_axis#

Partitioning strategy for the attention head dimension. Defaults to “tp”.

Type

Optional[Union[Tuple[str, …], str]]

key_sequence_axis#

Partitioning strategy for the key sequence dimension. Defaults to “sp”.

Type

Optional[Union[Tuple[str, …], str]]

hidden_state_axis#

Partitioning strategy for the hidden state dimension. Defaults to “tp”.

Type

Optional[Union[Tuple[str, …], str]]

attention_dim_axis#

Partitioning strategy for the attention dimension. Defaults to None.

Type

Optional[Union[Tuple[str, …], str]]

bias_head_sequence_axis#

Partitioning strategy for the bias head sequence dimension. Defaults to None.

Type

Optional[Union[Tuple[str, …], str]]

bias_key_sequence_axis#

Partitioning strategy for the bias key sequence dimension. Defaults to None.

Type

Optional[Union[Tuple[str, …], str]]

generation_query_sequence_axis#

Partitioning strategy for the query sequence dimension during generation. Defaults to None.

Type

Optional[Union[Tuple[str, …], str]]

generation_head_axis#

Partitioning strategy for the attention head dimension during generation. Defaults to “tp”.

Type

Optional[Union[Tuple[str, …], str]]

generation_key_sequence_axis#

Partitioning strategy for the key sequence dimension during generation. Defaults to “sp”.

Type

Optional[Union[Tuple[str, …], str]]

generation_attention_dim_axis#

Partitioning strategy for the attention dimension during generation. Defaults to None.

Type

Optional[Union[Tuple[str, …], str]]

attention_dim_axis: Optional[Union[Tuple[str, ...], str]]#

Alias for field number 6

batch_axis: Optional[Union[Tuple[str, ...], str]]#

Alias for field number 0

bias_head_sequence_axis: Optional[Union[Tuple[str, ...], str]]#

Alias for field number 7

bias_key_sequence_axis: Optional[Union[Tuple[str, ...], str]]#

Alias for field number 8

generation_attention_dim_axis: Optional[Union[Tuple[str, ...], str]]#

Alias for field number 12

generation_head_axis: Optional[Union[Tuple[str, ...], str]]#

Alias for field number 10

generation_key_sequence_axis: Optional[Union[Tuple[str, ...], str]]#

Alias for field number 11

generation_query_sequence_axis: Optional[Union[Tuple[str, ...], str]]#

Alias for field number 9

head_axis: Optional[Union[Tuple[str, ...], str]]#

Alias for field number 3

hidden_state_axis: Optional[Union[Tuple[str, ...], str]]#

Alias for field number 5

key_sequence_axis: Optional[Union[Tuple[str, ...], str]]#

Alias for field number 4

query_sequence_axis: Optional[Union[Tuple[str, ...], str]]#

Alias for field number 2

sequence_axis: Optional[Union[Tuple[str, ...], str]]#

Alias for field number 1
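Because PartitionAxis is a NamedTuple, each "Alias for field number N" entry means the attribute and the positional index address the same slot, and overrides are made with _replace rather than assignment. The sketch uses a local mirror class (PartitionAxisMirror, hypothetical) with the field order and defaults from the PartitionAxis signature:

```python
from typing import NamedTuple, Optional, Tuple, Union

AxisSpec = Optional[Union[Tuple[str, ...], str]]


class PartitionAxisMirror(NamedTuple):
    batch_axis: AxisSpec = ("fsdp", "dp")           # field 0
    sequence_axis: AxisSpec = "sp"                  # field 1
    query_sequence_axis: AxisSpec = "sp"            # field 2
    head_axis: AxisSpec = "tp"                      # field 3
    key_sequence_axis: AxisSpec = "sp"              # field 4
    hidden_state_axis: AxisSpec = "tp"              # field 5
    attention_dim_axis: AxisSpec = None             # field 6
    bias_head_sequence_axis: AxisSpec = None        # field 7
    bias_key_sequence_axis: AxisSpec = None         # field 8
    generation_query_sequence_axis: AxisSpec = None # field 9
    generation_head_axis: AxisSpec = "tp"           # field 10
    generation_key_sequence_axis: AxisSpec = "sp"   # field 11
    generation_attention_dim_axis: AxisSpec = None  # field 12


axes = PartitionAxisMirror()
# Positional index 3 and the head_axis attribute refer to the same value.
head_by_index, head_by_name = axes[3], axes.head_axis
# NamedTuples are immutable: _replace returns a modified copy,
# here one with sequence parallelism turned off.
no_sp = axes._replace(sequence_axis=None, query_sequence_axis=None, key_sequence_axis=None)
```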

class easydel.__init__.Phi3Config(vocab_size=32064, hidden_size=3072, intermediate_size=8192, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=None, resid_pdrop=0.0, embd_pdrop=0.0, attention_dropout=0.0, hidden_act='silu', max_position_embeddings=4096, original_max_position_embeddings=4096, initializer_range=0.02, rms_norm_eps=1e-05, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, bos_token_id=1, eos_token_id=32000, pad_token_id=32000, sliding_window=None, bits: Optional[int] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 32064) – Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 3072) – Dimension of the hidden representations.

  • intermediate_size (int, optional, defaults to 8192) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer decoder.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer decoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer decoder.

  • num_key_value_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer decoder. Will default to num_attention_heads if not set.

  • resid_pdrop (float, optional, defaults to 0.0) – The dropout probability applied to the attention and MLP outputs before they are added to the residual stream.

  • embd_pdrop (float, optional, defaults to 0.0) – The dropout ratio for the embeddings.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the decoder. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 4096) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • original_max_position_embeddings (int, optional, defaults to 4096) – The maximum sequence length the model was originally trained with; used as the reference length when RoPE scaling extends the context.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/value states (not used by all models). Only relevant if config.is_decoder=True.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.

  • bos_token_id (int, optional, defaults to 1) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 32000) – The id of the end-of-sequence token.

  • pad_token_id (int, optional, defaults to 32000) – The index of the padding token in the vocabulary.

  • sliding_window (int, optional) – The sliding window size.

  • bits (int, optional) – The number of bits to quantize the model to.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing policy.
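A sketch of the boolean attention mask implied by sliding_window, assuming each query may attend to itself and the previous sliding_window - 1 tokens. Whether the current token counts toward the window is an assumption here and varies between implementations; the function name is hypothetical:

```python
from typing import Optional

import numpy as np


def sliding_window_causal_mask(seq_len: int, sliding_window: Optional[int] = None) -> np.ndarray:
    i = np.arange(seq_len)[:, None]  # query positions (rows)
    j = np.arange(seq_len)[None, :]  # key positions (columns)
    mask = j <= i                    # plain causal mask
    if sliding_window is not None:
        # Additionally drop keys that are sliding_window or more steps in the past.
        mask &= (i - j) < sliding_window
    return mask  # True where attention is allowed
```

With sliding_window=None (the default) this reduces to an ordinary lower-triangular causal mask.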

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
model_type: str = 'phi3'#
class easydel.__init__.Phi3ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.Phi3Model(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

property frequencies#

Returns frequency values from the config.

class easydel.__init__.PhiConfig(vocab_size=51200, hidden_size=2048, intermediate_size=8192, num_hidden_layers=24, num_attention_heads=32, num_key_value_heads=None, resid_pdrop=0.0, embd_pdrop=0.0, attention_dropout=0.0, hidden_act='gelu_new', max_position_embeddings=2048, initializer_range=0.02, layer_norm_eps=1e-05, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, partial_rotary_factor=0.5, qk_layernorm=False, bos_token_id=1, eos_token_id=2, bits: Optional[int] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 51200) – Vocabulary size of the Phi model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 2048) – Dimension of the hidden representations.

  • intermediate_size (int, optional, defaults to 8192) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer decoder.

  • num_hidden_layers (int, optional, defaults to 24) – Number of hidden layers in the Transformer decoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer decoder.

  • num_key_value_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer decoder. Will default to num_attention_heads if not set.

  • resid_pdrop (float, optional, defaults to 0.0) – The dropout probability applied to the attention and MLP outputs before they are added to the residual stream.

  • embd_pdrop (float, optional, defaults to 0.0) – The dropout ratio for the embeddings.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • hidden_act (str or function, optional, defaults to “gelu_new”) – The non-linear activation function (function or string) to use in the decoder. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • layer_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/value states (not used by all models). Only relevant if config.is_decoder=True.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.

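  • partial_rotary_factor (float, optional, defaults to 0.5) – The fraction of each attention head's dimensions to which rotary position embeddings are applied (0.5 rotates half of the head dimension; the rest is left unrotated).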

  • qk_layernorm (bool, optional, defaults to False) – Whether to apply layer normalization to the query and key tensors.

  • bos_token_id (int, optional, defaults to 1) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.

  • bits (int, optional) – The number of bits to quantize the model to.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing policy.
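With the defaults (hidden_size=2048, num_attention_heads=32) each head has 64 dimensions, so partial_rotary_factor=0.5 rotates the first 32 and passes the remaining 32 through unchanged. A NumPy sketch, assuming the rotate-half RoPE layout used by Hugging Face's Phi implementation; the function name apply_partial_rotary is hypothetical:

```python
import numpy as np


def apply_partial_rotary(
    x: np.ndarray, positions: np.ndarray, partial_rotary_factor: float = 0.5, rope_theta: float = 10000.0
) -> np.ndarray:
    """Apply RoPE to the leading fraction of each head vector; x has shape (seq_len, head_dim)."""
    head_dim = x.shape[-1]
    rotary_dim = int(head_dim * partial_rotary_factor)
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    inv_freq = 1.0 / (rope_theta ** (np.arange(0, rotary_dim, 2) / rotary_dim))
    angles = np.outer(positions, inv_freq)                 # (seq_len, rotary_dim // 2)
    cos = np.concatenate([np.cos(angles)] * 2, axis=-1)    # (seq_len, rotary_dim)
    sin = np.concatenate([np.sin(angles)] * 2, axis=-1)
    half = rotary_dim // 2
    # rotate_half: pair dimension i with dimension i + half.
    rotated = np.concatenate([-x_rot[..., half:], x_rot[..., :half]], axis=-1)
    return np.concatenate([x_rot * cos + rotated * sin, x_pass], axis=-1)
```

Because only the rotated slice changes and each pair undergoes a pure rotation, the vector norm is preserved and position 0 is left untouched.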

attribute_map: Dict[str, str] = {'hidden_size': 'n_embd', 'max_position_embeddings': 'n_positions', 'num_attention_heads': 'num_attention_heads', 'num_hidden_layers': 'num_hidden_layers'}#
get_partition_rules(fully_sharded_data_parallel: bool = True)[source]#

Get the partition rules for the model.

Parameters

fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
model_type: str = 'phi'#
class easydel.__init__.PhiForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.PhiModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

property frequencies#

Returns frequency values from the config.

class easydel.__init__.PhiMoeConfig(vocab_size=32064, hidden_size=4096, intermediate_size=6400, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=8, hidden_act='silu', max_position_embeddings=131072, initializer_range=0.02, rms_norm_eps=1e-05, use_cache=True, pad_token_id=None, bos_token_id=1, eos_token_id=2, tie_word_embeddings=False, rope_theta=1000000.0, rope_scaling=None, sliding_window=None, attention_dropout=0.0, num_experts_per_tok=2, num_local_experts=16, output_router_logits=False, router_aux_loss_coef=0.001, router_jitter_noise=0.01, input_jitter_noise=0.0, attention_bias=False, embd_pdrop: float = 0.0, lm_head_bias=False, bits: Optional[int] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 32064) – Vocabulary size of the PhiMoE model. Defines the number of different tokens that can be represented by the input_ids passed when calling [PhiMoeModel].

  • hidden_size (int, optional, defaults to 4096) – Dimension of the hidden representations.

  • intermediate_size (int, optional, defaults to 6400) – Dimension of the MLP representations.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 8) – The number of key/value heads used to implement Grouped Query Attention (GQA). If num_key_value_heads=num_attention_heads, the model uses Multi-Head Attention (MHA); if num_key_value_heads=1, it uses Multi-Query Attention (MQA); otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group's key and value head should be constructed by mean-pooling all the original heads within that group. For more details, see [this paper](https://arxiv.org/pdf/2305.13245.pdf). If not specified, defaults to 8.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) in the decoder.

  • max_position_embeddings (int, optional, defaults to 131072, i.e. 4096*32) – The maximum sequence length that this model might ever be used with.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-05) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional) – The id of the padding token.

  • bos_token_id (int, optional, defaults to 1) – The id of the “beginning-of-sequence” token.

  • eos_token_id (int, optional, defaults to 2) – The id of the “end-of-sequence” token.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether the model’s input and output word embeddings should be tied.

  • rope_theta (float, optional, defaults to 1000000.0) – The base period of the RoPE embeddings.

  • rope_scaling (dict, optional) – The scaling strategy for the RoPE embeddings. If None, no scaling is applied. If a dictionary, it must contain the following keys: type, short_factor, long_factor, short_mscale, long_mscale and original_max_position_embeddings. type must be longrope; short_mscale and long_mscale must be numbers; short_factor and long_factor must be lists of numbers whose length is half the attention head size; and original_max_position_embeddings must be an integer.

  • sliding_window (int, optional) – Sliding window attention window size. If not specified, will default to 262144.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • num_experts_per_tok (int, optional, defaults to 2) – The number of experts each token is routed to; can also be interpreted as the top-k routing parameter.

  • num_local_experts (int, optional, defaults to 16) – Number of experts per Sparse MLP layer.

  • output_router_logits (bool, optional, defaults to False) – Whether or not the router logits should be returned by the model. Enabling this also allows the model to output the auxiliary loss.

  • router_aux_loss_coef (float, optional, defaults to 0.001) – The aux loss factor for the total loss.

  • router_jitter_noise (float, optional, defaults to 0.01) – Amount of noise to add to the router.

  • bits (int, optional) – The number of bits to quantize the model to.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.
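The num_key_value_heads description mentions building a GQA checkpoint from a multi-head one by mean-pooling the original heads in each group. A minimal pure-Python sketch of that conversion, plus the query-to-key/value head mapping it implies, is shown below. It assumes contiguous grouping (query head h shares key/value head h // (num_attention_heads // num_key_value_heads)), which is the usual convention but is an assumption here; each head is flattened to a plain list of floats, and both helper names are hypothetical.

```python
from typing import List

Head = List[float]  # one head's weights, flattened for illustration


def mean_pool_kv_heads(heads: List[Head], num_key_value_heads: int) -> List[Head]:
    """Average each contiguous group of MHA key (or value) heads into one GQA head."""
    num_heads = len(heads)
    if num_heads % num_key_value_heads:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    group = num_heads // num_key_value_heads
    pooled = []
    for g in range(num_key_value_heads):
        members = heads[g * group:(g + 1) * group]
        # Element-wise mean over the heads in this group.
        pooled.append([sum(vals) / group for vals in zip(*members)])
    return pooled


def kv_head_for_query_head(q_head: int, num_attention_heads: int, num_key_value_heads: int) -> int:
    """Index of the shared key/value head used by query head `q_head`."""
    return q_head // (num_attention_heads // num_key_value_heads)


# PhiMoE defaults: 32 query heads share 8 key/value heads, i.e. groups of 4.
heads = [[float(h)] for h in range(32)]
print(mean_pool_kv_heads(heads, 8)[0], kv_head_for_query_head(5, 32, 8))  # [1.5] 1
```

Setting num_key_value_heads equal to the head count leaves every head unchanged (MHA), while setting it to 1 averages all heads into a single shared one (MQA).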

attach_custom_arguments(bits: Optional[int] = None, embd_pdrop: float = 0.0, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)
get_partition_rules(fully_sharded_data_parallel: bool = True)

Get the partition rules for the model.

Parameters

fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

property granted_freq_max_position_embedding: int
property granted_mask_max_position_embedding: int
model_type: str = 'phimoe'
class easydel.__init__.PhiMoeForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.__init__.PhiMoeModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.__init__.PixtralVisionConfig(hidden_size: int = 1024, intermediate_size: int = 4096, num_hidden_layers: int = 24, num_attention_heads: int = 16, num_channels: int = 3, image_size: int = 1024, patch_size: int = 16, hidden_act: str = 'gelu', attention_dropout: float = 0.0, rope_theta: float = 10000.0, initializer_range: int = 0.02, **kwargs)

Bases: EasyDeLBaseConfig

This is the configuration class to store the configuration of a [PixtralVisionModel]. It is used to instantiate a Pixtral vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration similar to that of the vision encoder used by Pixtral-12B.

e.g. [pixtral-hf/pixtral-9b](https://huggingface.co/pixtral-hf/pixtral-9b)

Configuration objects inherit from [PretrainedConfig] and can be used to control the model outputs. Read the documentation from [PretrainedConfig] for more information.

Parameters
  • hidden_size (int, optional, defaults to 1024) – Dimension of the hidden representations.

  • intermediate_size (int, optional, defaults to 4096) – Dimension of the MLP representations.

  • num_hidden_layers (int, optional, defaults to 24) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 16) – Number of attention heads in the Transformer encoder.

  • num_channels (int, optional, defaults to 3) – Number of input channels in the input images.

  • image_size (int, optional, defaults to 1024) – Max dimension of the input images.

  • patch_size (int, optional, defaults to 16) – Size of the image patches.

  • hidden_act (str, optional, defaults to “gelu”) – Activation function used in the hidden layers.

  • attention_dropout (float, optional, defaults to 0.0) – Dropout probability for the attention layers.

  • rope_theta (float, optional, defaults to 10000.0) – The base period of the RoPE embeddings.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

Example:

```python
>>> from transformers import PixtralVisionModel, PixtralVisionConfig

>>> # Initializing a Pixtral-12B style configuration
>>> configuration = PixtralVisionConfig()
>>> # Initializing a model (with randomly initialized weights) from the configuration
>>> model = PixtralVisionModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
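With the defaults above, an image whose longer side is image_size=1024 pixels splits into at most 1024 / 16 = 64 patches per side, or 64 × 64 = 4096 patches. The sketch below estimates the patch grid for an arbitrary image. Its resizing rule (shrink the image so that its longer side fits within image_size, then round each side up to whole patches) is modeled loosely on the Hugging Face Pixtral image processor and is an assumption, as is the helper name patch_grid.

```python
import math
from typing import Tuple


def patch_grid(height: int, width: int, image_size: int = 1024, patch_size: int = 16) -> Tuple[int, int]:
    """Return (rows, cols) of patches after capping the longer side at image_size."""
    ratio = max(height, width) / image_size
    if ratio > 1:  # only shrink, never upscale
        height = math.floor(height / ratio)
        width = math.floor(width / ratio)
    # A partial patch at the border still counts as a full patch.
    return math.ceil(height / patch_size), math.ceil(width / patch_size)


print(patch_grid(1024, 1024))  # (64, 64)
print(patch_grid(2048, 1024))  # (64, 32)
print(patch_grid(100, 50))     # (7, 4)
```

Because the aspect ratio is preserved rather than forcing a square input, wide or tall images produce fewer patches than the 4096 upper bound.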
get_partition_rules(*args, **kwargs)

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'pixtral'
class easydel.__init__.PixtralVisionModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

property frequencies

Returns frequency values from the config.

class easydel.__init__.Qwen2Config(vocab_size=151936, hidden_size=4096, intermediate_size=22016, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=32, hidden_act='silu', max_position_embeddings=32768, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, use_sliding_window=False, sliding_window=4096, max_window_layers=28, attention_dropout=0.0, resid_pdrop: float = 0.0, embd_pdrop: float = 0.0, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, fcm_min_ratio: float = 0.0, fcm_max_ratio: float = 0.0, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, number_rep_kv: int = 1, bits: Optional[int] = None, scan_layers: bool = True, rope_scaling: Optional[Mapping[str, str | float]] = None, **kwargs)

Bases: EasyDeLBaseConfig

Configuration objects inherit from [PretrainedConfig] and can be used to control the model outputs. Read the documentation from [PretrainedConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 151936) – Vocabulary size of the Qwen-2 model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4096) – Dimension of the hidden representations.

  • intermediate_size (int, optional, defaults to 22016) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 32) – Number of key and value heads for each attention layer in the Transformer encoder.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) used in the MLP layers. If a string, values such as “silu”, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 32768) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • use_sliding_window (bool, optional, defaults to False) – Whether to use a sliding window attention.

  • sliding_window (int, optional, defaults to 4096) – The sliding window size.

  • max_window_layers (int, optional, defaults to 28) – The maximum number of layers to use for the sliding window attention.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • resid_pdrop (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

  • embd_pdrop (float, optional, defaults to 0.0) – The dropout ratio for the embeddings.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • fcm_min_ratio (float, optional, defaults to 0.0) – The minimum masking ratio for forgetful causal masking (FCM).

  • fcm_max_ratio (float, optional, defaults to 0.0) – The maximum masking ratio for forgetful causal masking (FCM).

  • use_scan_mlp (bool, optional, defaults to False) – Whether to use the scan implementation for the MLP.

  • scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size to use when scanning the MLP.

  • number_rep_kv (int, optional, defaults to 1) – Number of repetitions for the key and value vectors.

  • bits (int, optional) – The number of bits to quantize the model to.

  • scan_layers (bool, optional, defaults to True) – Whether to use the scan implementation for the layers.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.
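use_sliding_window and sliding_window restrict how far back each token can attend. The boolean mask below sketches plain causal attention versus sliding-window attention. It assumes a window that includes the current token (each query sees itself and the previous sliding_window - 1 positions); implementations differ on whether the boundary is inclusive, and the sketch does not model which layers max_window_layers selects. The function name causal_mask is hypothetical.

```python
from typing import List, Optional


def causal_mask(seq_len: int, sliding_window: Optional[int] = None) -> List[List[bool]]:
    """mask[i][j] is True when query position i may attend to key position j."""
    mask = []
    for i in range(seq_len):
        row = []
        for j in range(seq_len):
            allowed = j <= i  # causal: never look ahead
            if sliding_window is not None:
                allowed = allowed and (i - j) < sliding_window  # stay inside the window
            row.append(allowed)
        mask.append(row)
    return mask


for row in causal_mask(5, sliding_window=2):
    print("".join("x" if allowed else "." for allowed in row))
# x....
# xx...
# .xx..
# ..xx.
# ...xx
```

With a window, attention cost per token stays bounded by sliding_window instead of growing with the sequence length.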

attach_custom_arguments(resid_pdrop: float = 0.0, embd_pdrop: float = 0.0, attention_dropout: float = 0.0, tie_word_embeddings: bool = False, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, fcm_min_ratio: float = 0.0, fcm_max_ratio: float = 0.0, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, number_rep_kv: int = 1, bits: Optional[int] = None, rope_theta: float = 10000.0, hidden_act: str = 'silu', scan_layers: bool = True, rope_scaling: Optional[Mapping[str, str | float]] = None, **kwargs)

Attaches the following EasyDeL-specific arguments to the configuration:

Parameters
  • self – The configuration object being updated.

  • resid_pdrop – float: The dropout rate for residual connections.

  • embd_pdrop – float: The dropout probability for embeddings.

  • attention_dropout – float: The dropout probability for the attention weights.

  • tie_word_embeddings – bool: Whether to tie the input word embeddings to the output embeddings.

  • gradient_checkpointing – EasyDeLGradientCheckPointers: The gradient checkpointing policy, which trades recomputation for lower activation memory in JAX.

  • fcm_min_ratio – float: The minimum masking ratio for forgetful causal masking (FCM).

  • fcm_max_ratio – float: The maximum masking ratio for forgetful causal masking (FCM).

  • use_scan_mlp – bool: Whether to use the scan_mlp implementation.

  • scan_mlp_chunk_size – int: The chunk size used by scan_mlp.

  • number_rep_kv – int: How many times the key and value vectors are repeated.

  • bits – tp.Optional[int]: The number of bits used for quantization.

  • rope_theta – float: The base period (theta) used to compute the RoPE frequencies.

  • hidden_act – str: The activation function used in the MLP.

  • scan_layers – bool: Whether to scan over the layers.

get_partition_rules(fully_sharded_data_parallel: bool = True)

Get the partition rules for the model.

Parameters

fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()
property granted_freq_max_position_embedding: int
property granted_mask_max_position_embedding: int
model_type: str = 'qwen2'
static rng_keys()
class easydel.__init__.Qwen2ForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.__init__.Qwen2ForSequenceClassification(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.__init__.Qwen2Model(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.__init__.Qwen2MoeConfig(vocab_size=151936, hidden_size=2048, intermediate_size=5632, num_hidden_layers=24, num_attention_heads=16, num_key_value_heads=16, hidden_act='silu', max_position_embeddings=32768, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, use_sliding_window=False, sliding_window=4096, max_window_layers=28, attention_dropout=0.0, decoder_sparse_step=1, moe_intermediate_size=1408, shared_expert_intermediate_size=5632, num_experts_per_tok=4, num_experts=60, norm_topk_prob=False, output_router_logits=False, router_aux_loss_coef=0.001, mlp_only_layers=None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 151936) – Vocabulary size of the Qwen-2 MoE model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 2048) – Dimension of the hidden representations.

  • intermediate_size (int, optional, defaults to 5632) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 24) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 16) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 16) – Number of key and value heads for each attention layer in the Transformer encoder.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) used in the MLP layers. If a string, values such as “silu”, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 32768) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • use_sliding_window (bool, optional, defaults to False) – Whether to use a sliding window attention.

  • sliding_window (int, optional, defaults to 4096) – The sliding window size.

  • max_window_layers (int, optional, defaults to 28) – The maximum number of layers to use for the sliding window attention.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • decoder_sparse_step (int, optional, defaults to 1) – The frequency of sparse MoE layers among the decoder layers (1 makes every layer an MoE layer, 2 every second layer, and so on).

  • moe_intermediate_size (int, optional, defaults to 1408) – The intermediate size of the MoE layer.

  • shared_expert_intermediate_size (int, optional, defaults to 5632) – The intermediate size of the shared expert.

  • num_experts_per_tok (int, optional, defaults to 4) – The number of experts each token is routed to (the top-k routing parameter).

  • num_experts (int, optional, defaults to 60) – The number of experts.

  • norm_topk_prob (bool, optional, defaults to False) – Whether to renormalize the top-k routing probabilities so that they sum to 1.

  • output_router_logits (bool, optional, defaults to False) – Whether to output the router logits.

  • router_aux_loss_coef (float, optional, defaults to 0.001) – The coefficient for the router auxiliary loss.

  • mlp_only_layers (list of int, optional) – Indices of the layers that use a dense MLP instead of a sparse MoE block.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • bits (int, optional) – The number of bits to quantize the model to.
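In the Hugging Face transformers implementation of Qwen2-MoE, the MoE fields above interact as follows. A decoder layer with index i uses a sparse MoE block when it is not listed in mlp_only_layers, num_experts > 0, and (i + 1) is divisible by decoder_sparse_step. For each token, a softmax over all num_experts router logits gives routing probabilities, the top num_experts_per_tok are kept, and norm_topk_prob optionally rescales the kept weights to sum to 1. The pure-Python sketch below reproduces those two rules; that EasyDeL's port behaves identically is an assumption, the helper names are hypothetical, and the always-on shared expert (shared_expert_intermediate_size) is not modeled.

```python
import math
from typing import List, Optional, Sequence, Tuple


def is_moe_layer(layer_idx: int, decoder_sparse_step: int = 1, num_experts: int = 60,
                 mlp_only_layers: Optional[Sequence[int]] = None) -> bool:
    """True if decoder layer `layer_idx` gets a sparse MoE block rather than a dense MLP."""
    dense_layers = mlp_only_layers or []
    return (layer_idx not in dense_layers
            and num_experts > 0
            and (layer_idx + 1) % decoder_sparse_step == 0)


def route_token(router_logits: List[float], num_experts_per_tok: int = 4,
                norm_topk_prob: bool = False) -> List[Tuple[int, float]]:
    """Return (expert_index, weight) pairs for one token, highest weight first."""
    # Softmax over all experts (shifted by the max for numerical stability).
    peak = max(router_logits)
    exps = [math.exp(x - peak) for x in router_logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the top-k experts by probability.
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:num_experts_per_tok]
    weights = [probs[i] for i in top]
    if norm_topk_prob:
        # Rescale the kept weights so they sum to 1.
        kept = sum(weights)
        weights = [w / kept for w in weights]
    return list(zip(top, weights))


print([i for i in range(6) if is_moe_layer(i, decoder_sparse_step=2)])  # [1, 3, 5]
print(route_token([0.0, 2.0, 1.0, -1.0], num_experts_per_tok=2, norm_topk_prob=True))
```

With norm_topk_prob=False (the default), the kept weights sum to less than 1, so the MoE output is implicitly scaled by how much probability mass the top-k experts captured.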

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)

Attaches the following EasyDeL-specific arguments to the configuration:

Parameters
  • self – The configuration object being updated.

  • gradient_checkpointing – EasyDeLGradientCheckPointers: The gradient checkpointing policy, which trades recomputation for lower activation memory in JAX.

  • bits – tp.Optional[int]: The number of bits used for quantization.

get_partition_rules(fully_sharded_data_parallel: bool = True)

Get the partition rules for the model.

Parameters

fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()
property granted_freq_max_position_embedding: int
property granted_mask_max_position_embedding: int
model_type: str = 'qwen2_moe'
static rng_keys()
class easydel.__init__.Qwen2MoeForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.__init__.Qwen2MoeForSequenceClassification(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.__init__.Qwen2MoeModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.__init__.Qwen2VLConfig(vocab_size=152064, hidden_size=8192, intermediate_size=29568, num_hidden_layers=80, num_attention_heads=64, num_key_value_heads=8, hidden_act='silu', max_position_embeddings=32768, initializer_range=0.02, rms_norm_eps=1e-05, use_cache=True, tie_word_embeddings=False, rope_theta=1000000.0, use_sliding_window=False, sliding_window=4096, max_window_layers=80, attention_dropout=0.0, vision_config=None, rope_scaling=None, vision_start_token_id=151652, vision_end_token_id=151653, vision_token_id=151654, image_token_id=151655, video_token_id=151656, **kwargs)

Bases: EasyDeLBaseConfig

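This is the configuration class to store the configuration of a [Qwen2VLModel]. It is used to instantiate a Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults is intended to yield a configuration similar to that of Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). Note, however, that the language-model sizes in the signature (hidden_size=8192, num_hidden_layers=80, num_attention_heads=64) are much larger than those of that 7B checkpoint, so override them when you need a 7B-scale model.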

Configuration objects inherit from [PretrainedConfig] and can be used to control the model outputs. Read the documentation from [PretrainedConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 152064) – Vocabulary size of the Qwen2VL model. Defines the number of different tokens that can be represented by the input_ids passed when calling [Qwen2VLModel].

  • hidden_size (int, optional, defaults to 8192) – Dimension of the hidden representations.

  • intermediate_size (int, optional, defaults to 29568) – Dimension of the MLP representations.

  • num_hidden_layers (int, optional, defaults to 80) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 64) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 8) – The number of key/value heads used to implement Grouped Query Attention (GQA). If num_key_value_heads=num_attention_heads, the model uses Multi-Head Attention (MHA); if num_key_value_heads=1, it uses Multi-Query Attention (MQA); otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group's key and value head should be constructed by mean-pooling all the original heads within that group. For more details, see [this paper](https://arxiv.org/pdf/2305.13245.pdf).

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) in the decoder.

  • max_position_embeddings (int, optional, defaults to 32768) – The maximum sequence length that this model might ever be used with.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-05) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether the model’s input and output word embeddings should be tied.

  • rope_theta (float, optional, defaults to 1000000.0) – The base period of the RoPE embeddings.

  • use_sliding_window (bool, optional, defaults to False) – Whether to use sliding window attention.

  • sliding_window (int, optional, defaults to 4096) – Sliding window attention (SWA) window size. If not specified, will default to 4096.

  • max_window_layers (int, optional, defaults to 80) – The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • vision_config (tp.Dict, optional) – The config for the visual encoder initialization.

  • rope_scaling (tp.Dict, optional) –

    Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new RoPE type and expect the model to work on a longer max_position_embeddings, we recommend updating this value accordingly. Expected contents:

    rope_type (str):

    The sub-variant of RoPE to use. Can be one of [‘default’, ‘linear’, ‘dynamic’, ‘yarn’, ‘longrope’, ‘llama3’], with ‘default’ being the original RoPE implementation.

    factor (float, optional):

    Used with all rope types except ‘default’. The scaling factor to apply to the RoPE embeddings. In most scaling types, a factor of x will enable the model to handle sequences of length x * original maximum pre-trained length.

    original_max_position_embeddings (int, optional):

    Used with ‘dynamic’, ‘longrope’ and ‘llama3’. The original max position embeddings used during pretraining.

    attention_factor (float, optional):

    Used with ‘yarn’ and ‘longrope’. The scaling factor to be applied on the attention computation. If unspecified, it defaults to the value recommended by the implementation, using the factor field to infer the suggested value.

    beta_fast (float, optional):

    Only used with ‘yarn’. Parameter to set the boundary for extrapolation (only) in the linear ramp function. If unspecified, it defaults to 32.

    beta_slow (float, optional):

    Only used with ‘yarn’. Parameter to set the boundary for interpolation (only) in the linear ramp function. If unspecified, it defaults to 1.

    short_factor (tp.List[float], optional):

    Only used with ‘longrope’. The scaling factor to be applied to short contexts (< original_max_position_embeddings). Must be a list of numbers whose length is the hidden size divided by the number of attention heads, divided by 2.

    long_factor (tp.List[float], optional):

    Only used with ‘longrope’. The scaling factor to be applied to long contexts (> original_max_position_embeddings). Must be a list of numbers whose length is the hidden size divided by the number of attention heads, divided by 2.

    low_freq_factor (float, optional):

    Only used with ‘llama3’. Scaling factor applied to low-frequency components of the RoPE.

    high_freq_factor (float, optional):

    Only used with ‘llama3’. Scaling factor applied to high-frequency components of the RoPE.

```python
>>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig

>>> # Initializing a Qwen2VL style configuration
>>> configuration = Qwen2VLConfig()
>>> # Initializing a model (with randomly initialized weights) from the configuration
>>> model = Qwen2VLForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
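The rope_scaling rules listed above can be checked before building a model. The sketch below validates the rope_type, requires a factor for the ‘linear’, ‘dynamic’, ‘yarn’ and ‘llama3’ types (stricter than the "optional" label above; this mirrors, as an assumption, the checks Hugging Face transformers applies), and verifies that the ‘longrope’ short_factor and long_factor lists have hidden_size // num_attention_heads // 2 entries. The function check_rope_scaling is hypothetical, reads only the rope_type key, and is not EasyDeL's own validator.

```python
from typing import Any, Dict, Optional

ROPE_TYPES = {"default", "linear", "dynamic", "yarn", "longrope", "llama3"}
NEEDS_FACTOR = {"linear", "dynamic", "yarn", "llama3"}  # assumption, see lead-in


def check_rope_scaling(rope_scaling: Optional[Dict[str, Any]], hidden_size: int,
                       num_attention_heads: int) -> None:
    """Raise ValueError if `rope_scaling` violates the rules sketched here."""
    if rope_scaling is None:
        return  # no scaling requested
    rope_type = rope_scaling.get("rope_type", "default")
    if rope_type not in ROPE_TYPES:
        raise ValueError(f"unknown rope_type {rope_type!r}")
    if rope_type in NEEDS_FACTOR and "factor" not in rope_scaling:
        raise ValueError(f"rope_type {rope_type!r} requires a 'factor'")
    if rope_type == "longrope":
        # One entry per rotated dimension pair of a single attention head.
        expected = hidden_size // num_attention_heads // 2
        for key in ("short_factor", "long_factor"):
            values = rope_scaling.get(key)
            if not isinstance(values, list) or len(values) != expected:
                raise ValueError(f"{key} must be a list of {expected} numbers")


# Qwen2-VL defaults: 8192 // 64 // 2 = 64 factors per list.
check_rope_scaling({"rope_type": "longrope", "short_factor": [1.0] * 64,
                    "long_factor": [1.0] * 64}, hidden_size=8192, num_attention_heads=64)
```

Failing fast on a malformed dictionary is cheaper than discovering a shape mismatch only when the rotary tables are first built.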
get_partition_rules(*args, **kwargs)

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

keys_to_ignore_at_inference = ['past_key_values']
model_type: str = 'qwen2_vl'
sub_configs: Dict[str, 'PretrainedConfig'] = {'vision_config': <class 'easydel.__init__.modules.qwen2_vl.qwen2_vl_configuration.Qwen2VLVisionConfig'>}
class easydel.__init__.Qwen2VLForConditionalGeneration(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

get_decoder()
get_input_embeddings()
get_output_embeddings()
get_static_arguments()

Returns the static keyword arguments to pass to jax.jit.

loss_type = 'ForCausalLM'
prepare_inputs_for_call(image_grid_thw: Optional[Union[Array, ndarray, bool, number]] = None, video_grid_thw: Optional[Union[Array, ndarray, bool, number]] = None, image_max_grid_size: int = None, video_max_grid_size: int = None, drop_ids: bool = True, **others)

Updates the inputs before calling the model.

prepare_inputs_for_generation(input_ids, max_length, past_key_values=None, attention_mask=None, inputs_embeds=None, position_ids=None, pixel_values=None, pixel_values_videos=None, image_grid_thw=None, video_grid_thw=None, **kwargs)

Prepares the model inputs for a generation task.

Parameters
  • self – The model instance.

  • input_ids – The input token IDs.

  • max_length – The maximum length of the sequence to be generated.

  • attention_mask – tp.Optional[chex.Array]: The attention mask over the input tokens.

Returns

A dictionary containing the past_key_values, attention_mask and position ids.

update_inputs_for_generation(model_outputs, model_kwargs)
class easydel.__init__.Qwen2VLModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.__init__.RewardConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: ~typing.Optional[str] = None, clip_grad: ~typing.Optional[float] = None, custom_scheduler: ~typing.Optional[~typing.Callable[[int], ~typing.Any]] = None, dataloader_num_workers: ~typing.Optional[int] = 0, dataloader_pin_memory: ~typing.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: ~typing.Optional[int] = None, evaluation_steps: ~typing.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: ~typing.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: ~typing.Optional[~typing.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: ~typing.Optional[dict] = None, learning_rate: float = 5e-05, learning_rate_end: ~typing.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: ~typing.Optional[~easydel.infra.loss_utils.LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: ~typing.Optional[int] = None, max_sequence_length: ~typing.Optional[int] = 4096, max_training_steps: ~typing.Optional[int] = None, model_name: str = 'BaseTrainer', model_parameters: ~typing.Optional[dict] = None, metrics_to_show_in_rich_pbar: ~typing.Optional[~typing.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: ~typing.Literal['adafactor', 'lion', 'adamw', 'rmsprop'] = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: ~typing.Any = None, process_zero_is_admin: bool = True, progress_bar_type: ~typing.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', 
save_optimizer_state: bool = True, save_steps: ~typing.Optional[int] = None, save_total_limit: ~typing.Optional[int] = None, scheduler: ~typing.Literal['linear', 'cosine', 'none'] = EasyDeLSchedulers.NONE, sparsify_module: bool = False, sparse_module_type: ~typing.Literal['bcoo', 'bcsr', 'coo', 'csr'] = 'bcoo', state_apply_fn_kwarguments_to_model: ~typing.Optional[dict] = None, step_partition_spec: ~jax._src.partition_spec.PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: ~typing.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: ~typing.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: ~typing.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: ~typing.Optional[~numpy.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: ~typing.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0)

Bases: TrainingArguments

Configuration class for the [RewardTrainer].

Parameters
  • model_name (str) – The name of the model. Defaults to “RewardTrainer”.

  • max_sequence_length (int, optional) – Maximum length of the sequences (prompt + completion) in the batch; entries that exceed the limit are filtered out. Defaults to 1024.

  • disable_dropout (bool, optional) – Whether to disable dropout in the model. Defaults to True.

  • dataset_num_proc (int, optional) – Number of processes to use for processing the dataset. Defaults to None.

  • center_rewards_coefficient (float, optional) – Coefficient to incentivize the reward model to output mean-zero rewards. Defaults to 0.1.

  • remove_unused_columns (bool, optional) – Whether to remove the columns that are not used by the model’s forward pass. Can be True only if the dataset is pretokenized. Defaults to False.
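The effect of center_rewards_coefficient can be illustrated with the pairwise reward-model loss used by TRL's RewardTrainer, which this configuration appears to mirror: a Bradley-Terry term on the chosen/rejected reward gap plus the coefficient times the mean of (r_chosen + r_rejected)². This is a minimal, dependency-free sketch under that assumption; EasyDeL's exact loss may differ, and `reward_loss` is a hypothetical helper, not part of the library.

```python
import math
from typing import Optional, Sequence


def log_sigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x)).
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def reward_loss(
    rewards_chosen: Sequence[float],
    rewards_rejected: Sequence[float],
    center_rewards_coefficient: Optional[float] = 0.1,
) -> float:
    """Hypothetical helper: TRL-style pairwise loss with an optional centering penalty."""
    if len(rewards_chosen) != len(rewards_rejected) or not rewards_chosen:
        raise ValueError("expected two non-empty reward lists of equal length")
    n = len(rewards_chosen)
    # Bradley-Terry term: pushes each chosen reward above its rejected counterpart.
    loss = -sum(log_sigmoid(c - r) for c, r in zip(rewards_chosen, rewards_rejected)) / n
    if center_rewards_coefficient is not None:
        # Penalize (r_chosen + r_rejected)^2 so that rewards stay centered around zero.
        penalty = sum((c + r) ** 2 for c, r in zip(rewards_chosen, rewards_rejected)) / n
        loss += center_rewards_coefficient * penalty
    return loss
```

Because the pairwise term depends only on the difference r_chosen - r_rejected, adding the same constant to every reward leaves it unchanged; the centering term is what keeps that free offset near zero.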

center_rewards_coefficient: Optional[float] = 0.1 – Coefficient to incentivize the reward model to output mean-zero rewards.
dataset_num_proc: Optional[int] = None – Number of processes to use for processing the dataset.
disable_dropout: bool = True – Whether to disable dropout in the model.
max_sequence_length: Optional[int] = 1024 – Maximum length of the sequences (prompt + completion) in the batch, filters out entries that exceed the limit.
model_name: str = 'RewardTrainer' – The name of the model.
remove_unused_columns: bool = False – Whether to remove the columns that are not used by the model's forward pass. Can be True only if the dataset is pretokenized.
replace(**kwargs)
class easydel.__init__.RewardTrainer(arguments: RewardConfig, processing_class: Any, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, data_collator: Optional[RewardDataCollatorWithPadding] = None)

Bases: Trainer

This trainer extends the Trainer with functionality specific to training reward models.

configure_functions() → TrainerConfigureFunctionOutput

Configures and JIT-compiles the training and evaluation step functions.

This method prepares the functions that will be used during training and evaluation. It sets up sharding for the model parameters and optimizer state, JIT-compiles the training and evaluation functions with the appropriate static arguments and sharding constraints, and also sets up the checkpoint manager.

Returns

An object containing:
  • sharded_training_step_function: The compiled training step function.

  • sharded_evaluation_step_function: The compiled evaluation step function.

  • mesh: The device mesh used for computation.

  • checkpoint_manager: The checkpointer for saving/loading model state.

Return type

TrainerConfigureFunctionOutput

create_collect_function(max_sequence_length, truncation_mode='keep_end')

Creates a collate/collect function to process batches of data for training or evaluation.

This function returns a callable that takes a batch (a list of dictionaries) and converts it into a dictionary of JAX arrays. For models of class “ForCausalLMLoss”, it also performs truncation (either keeping the end or the start of the sequence) so that each sequence does not exceed the specified maximum length.

Parameters
  • max_sequence_length (int) – The maximum allowed sequence length.

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – Determines whether to keep the end or the start of the sequence when truncating. Defaults to “keep_end”.

Returns

A function that takes a batch (list of dicts) and returns a processed dict of arrays.

Return type

tp.Callable

class easydel.__init__.RobertaConfig(vocab_size=50265, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act='gelu', hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=514, type_vocab_size=1, initializer_range=0.02, layer_norm_eps=1e-05, pad_token_id=1, bos_token_id=0, eos_token_id=2, position_embedding_type='absolute', use_cache=True, classifier_dropout=None, gradient_checkpointing='nothing_saveable', **kwargs)

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 50265) – Vocabulary size of the RoBERTa model. Defines the number of different tokens that can be represented by the inputs_ids passed when calling RobertaModel.

  • hidden_size (int, optional, defaults to 768) – Dimensionality of the encoder layers and the pooler layer.

  • num_hidden_layers (int, optional, defaults to 12) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 12) – Number of attention heads for each attention layer in the Transformer encoder.

  • intermediate_size (int, optional, defaults to 3072) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • hidden_act (str or function, optional, defaults to "gelu") – The non-linear activation function (function or string) in the encoder and pooler. If string, "gelu", "relu", "swish" and "gelu_new" are supported.

  • hidden_dropout_prob (float, optional, defaults to 0.1) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

  • attention_probs_dropout_prob (float, optional, defaults to 0.1) – The dropout ratio for the attention probabilities.

  • max_position_embeddings (int, optional, defaults to 514) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048).

  • type_vocab_size (int, optional, defaults to 1) – The vocabulary size of the token_type_ids passed when calling RobertaModel.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • layer_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.

  • position_embedding_type (str, optional, defaults to "absolute") – Type of position embedding. Choose one of "absolute", "relative_key", "relative_key_query". For positional embeddings use "absolute". For more information on "relative_key", please refer to [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). For more information on "relative_key_query", please refer to Method 4 in [Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • classifier_dropout (float, optional) – The dropout ratio for the classification head.

  • gradient_checkpointing (str, optional, defaults to "nothing_saveable") – What to save during gradient checkpointing. Choose one of "nothing_saveable", "first_half_saveable", "full_saveable".
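For position_embedding_type="relative_key", the Hugging Face RoBERTa implementation (which EasyDeL's port is assumed here to follow) looks up a learned embedding for every query/key offset: the table has 2 * max_position_embeddings - 1 rows and is indexed by (query_position - key_position) + max_position_embeddings - 1. The sketch below only computes those indices; `relative_key_indices` is a hypothetical helper.

```python
from typing import List


def relative_key_indices(
    query_length: int, key_length: int, max_position_embeddings: int = 514
) -> List[List[int]]:
    """Row i, column j holds the distance-embedding index for query position i and key position j."""
    # Shift signed distances (i - j) into the range [0, 2 * max_position_embeddings - 2].
    offset = max_position_embeddings - 1
    return [[(i - j) + offset for j in range(key_length)] for i in range(query_length)]
```

For example, relative_key_indices(3, 3) gives [[513, 512, 511], [514, 513, 512], [515, 514, 513]]: the diagonal (zero distance) always maps to row max_position_embeddings - 1.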

attach_custom_arguments(gradient_checkpointing='nothing_saveable', **kwargs)
get_partition_rules(fully_sharded_data_parallel: bool = True)

Get the partition rules for the model.

Parameters

fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]
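Partition rules of this shape pair a regular expression with a sharding spec, and each parameter path is assigned the spec of a rule whose pattern matches it. The sketch below assumes "first match wins" (an assumption about EasyDeL's matcher), uses plain tuples as a stand-in for PartitionSpec so it runs without JAX, and uses purely hypothetical patterns and mesh-axis names rather than the ones RobertaConfig actually returns.

```python
import re
from typing import Sequence, Tuple

# Stand-in for jax.sharding.PartitionSpec; an empty tuple means "fully replicated".
Spec = Tuple[object, ...]

# Hypothetical rules for illustration only.
RULES: Sequence[Tuple[str, Spec]] = (
    (r"embeddings/word_embeddings/embedding", ("tp", ("fsdp", "sp"))),
    (r"attention/self/(query|key|value)/kernel", (("fsdp", "sp"), "tp")),
    (r".*bias", ()),
    (r".*", ()),  # catch-all: replicate everything else
)


def match_rule(param_path: str, rules: Sequence[Tuple[str, Spec]] = RULES) -> Spec:
    """Return the spec of the first rule whose regex is found in the parameter path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")
```

Under the first-match assumption, rule order matters: specific patterns must precede broad ones, and a catch-all belongs at the end.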

model_type: str = 'roberta'
class easydel.__init__.RobertaForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.__init__.RobertaForMultipleChoice(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.__init__.RobertaForQuestionAnswering(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.__init__.RobertaForSequenceClassification(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.__init__.RobertaForTokenClassification(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.__init__.SFTConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: ~typing.Optional[str] = None, clip_grad: ~typing.Optional[float] = None, custom_scheduler: ~typing.Optional[~typing.Callable[[int], ~typing.Any]] = None, dataloader_num_workers: ~typing.Optional[int] = 0, dataloader_pin_memory: ~typing.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: ~typing.Optional[int] = None, evaluation_steps: ~typing.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: ~typing.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: ~typing.Optional[~typing.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: ~typing.Optional[dict] = None, learning_rate: float = 5e-05, learning_rate_end: ~typing.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: ~typing.Optional[~easydel.infra.loss_utils.LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: ~typing.Optional[int] = None, max_sequence_length: ~typing.Optional[int] = 4096, max_training_steps: ~typing.Optional[int] = None, model_name: str = 'BaseTrainer', model_parameters: ~typing.Optional[dict] = None, metrics_to_show_in_rich_pbar: ~typing.Optional[~typing.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: ~typing.Literal['adafactor', 'lion', 'adamw', 'rmsprop'] = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: ~typing.Any = None, process_zero_is_admin: bool = True, progress_bar_type: ~typing.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: ~typing.Optional[int] = None, save_total_limit: ~typing.Optional[int] = None, scheduler: ~typing.Literal['linear', 'cosine', 'none'] = EasyDeLSchedulers.NONE, sparsify_module: bool = False, sparse_module_type: ~typing.Literal['bcoo', 'bcsr', 'coo', 'csr'] = 'bcoo', state_apply_fn_kwarguments_to_model: ~typing.Optional[dict] = None, step_partition_spec: ~jax._src.partition_spec.PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: ~typing.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: ~typing.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: ~typing.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: ~typing.Optional[~numpy.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: ~typing.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0)

Bases: TrainingArguments

Configuration class for the [SFTTrainer].

Parameters
  • model_name (str) – The name of the model. Defaults to “SFTTrainer”.

  • dataset_text_field (str, optional) – Name of the text field of the dataset. If provided, the trainer will automatically create a [ConstantLengthDataset] based on dataset_text_field. Defaults to None.

  • packing (bool, optional) – Controls whether the [ConstantLengthDataset] packs the sequences of the dataset. Defaults to False.

  • learning_rate (float, optional) – Initial learning rate for the [AdamW] optimizer. This default overrides the one inherited from [TrainingArguments]. Defaults to 2e-5.

  • dataset_num_proc (int, optional) – Number of processes to use for processing the dataset. Only used when packing=False. Defaults to None.

  • dataset_batch_size (int, optional) – Number of examples to tokenize per batch. If dataset_batch_size <= 0 or dataset_batch_size is None, tokenizes the full dataset as a single batch. Defaults to 1000.

  • dataset_kwargs (dict[str, Any], optional) – Dictionary of optional keyword arguments to pass when creating packed or non-packed datasets. Defaults to None.

  • eval_packing (bool, optional) – Whether to pack the eval dataset. If None, uses the same value as packing. Defaults to None.

  • num_of_sequences (int, optional) – Number of sequences to use for the [ConstantLengthDataset]. Defaults to 1024.

  • chars_per_token (float, optional) – Number of characters per token to use for the [ConstantLengthDataset]. See [chars_token_ratio](huggingface/trl) for more details. Defaults to 3.6.
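With packing enabled, a [ConstantLengthDataset] as implemented in TRL (which these options mirror) buffers roughly seq_length * chars_per_token * num_of_sequences characters of text, tokenizes them, joins the examples with an end-of-sequence token, and slices the resulting stream into equal-length training sequences, dropping the incomplete tail. The sketch below reproduces that idea without dependencies; the helper names are hypothetical, seq_length is assumed to correspond to the maximum sequence length, and EasyDeL's implementation may differ in details such as the separator token.

```python
from typing import Iterable, List


def pack_sequences(
    token_lists: Iterable[List[int]], seq_length: int, eos_token_id: int
) -> List[List[int]]:
    """Concatenate tokenized examples (EOS-separated) and cut the stream into fixed-length chunks."""
    stream: List[int] = []
    for tokens in token_lists:
        stream.extend(tokens)
        stream.append(eos_token_id)  # separator between consecutive examples
    chunks = [stream[i : i + seq_length] for i in range(0, len(stream), seq_length)]
    # Keep only full chunks; the incomplete tail is discarded.
    return [chunk for chunk in chunks if len(chunk) == seq_length]


def buffer_chars(seq_length: int, chars_per_token: float = 3.6, num_of_sequences: int = 1024) -> float:
    """Approximate number of characters collected before one round of tokenization and packing."""
    return seq_length * chars_per_token * num_of_sequences
```

For example, pack_sequences([[1, 2, 3], [4, 5], [6, 7, 8, 9]], seq_length=4, eos_token_id=0) yields [[1, 2, 3, 0], [4, 5, 0, 6], [7, 8, 9, 0]]: examples may straddle chunk boundaries, which is the trade-off packing makes to avoid padding.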

add_special_tokens: bool = False – Whether to add special tokens.
chars_per_token: float = 3.6 – Number of characters per token to use for the ConstantLengthDataset.
dataset_batch_size: int = 1000 – Number of examples to tokenize per batch.
dataset_kwargs: Optional[dict[str, Any]] = None – Dictionary of optional keyword arguments to pass when creating datasets.
dataset_num_proc: Optional[int] = None – Number of processes to use for processing the dataset.
dataset_text_field: Optional[str] = None – Name of the text field of the dataset.
eval_packing: Optional[bool] = None – Whether to pack the eval dataset. If None, uses the same value as packing.
learning_rate: float = 2e-05 – Initial learning rate for the AdamW optimizer.
model_name: str = 'SFTTrainer' – The name of the model.
num_of_sequences: int = 1024 – Number of sequences to use for the ConstantLengthDataset.
packing: bool = False – Controls whether the sequences of the dataset are packed.
replace(**kwargs)
class easydel.__init__.SFTTrainer(arguments: SFTConfig, processing_class: Any, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, formatting_func: Optional[Callable] = None, data_collator: Optional[DataCollatorForCompletionOnlyLM] = None)

Bases: Trainer

Trainer class for Supervised Fine-Tuning (SFT) of language models.

This trainer extends the Trainer and provides functionalities specific to supervised fine-tuning tasks.

class easydel.__init__.SiglipConfig(text_config=None, vision_config=None, **kwargs)

Bases: EasyDeLBaseConfig

[SiglipConfig] is the configuration class to store the configuration of a [SiglipModel]. It is used to instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • text_config (dict, optional) – Dictionary of configuration options used to initialize [SiglipTextConfig].

  • vision_config (dict, optional) – Dictionary of configuration options used to initialize [SiglipVisionConfig].

  • kwargs (optional) – Dictionary of keyword arguments.

classmethod from_text_vision_configs(text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs)

Instantiate a [SiglipConfig] (or a derived class) from a SigLIP text model configuration and a SigLIP vision model configuration.

Returns

An instance of a configuration object

Return type

[SiglipConfig]

get_partition_rules(*args, **kwargs)

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'siglip'
sub_configs: Dict[str, 'PretrainedConfig'] = {'text_config': <class 'easydel.__init__.modules.siglip.configuration_siglip.SiglipTextConfig'>, 'vision_config': <class 'easydel.__init__.modules.siglip.configuration_siglip.SiglipVisionConfig'>}
class easydel.__init__.SiglipForImageClassification(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.__init__.SiglipModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

get_image_features(pixel_values: Optional[Union[Array, ndarray, bool, number]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, interpolate_pos_encoding: bool = False) → Union[Array, ndarray, bool, number]
get_text_features(input_ids: Optional[Union[Array, ndarray, bool, number]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None, position_ids: Optional[Union[Array, ndarray, bool, number]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None) → Union[Array, ndarray, bool, number]
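Text and image features like those returned by get_text_features and get_image_features are usually scored the SigLIP way: L2-normalize both vectors, take their dot product, multiply by exp(logit_scale), add logit_bias, and apply an independent sigmoid to each text-image pair rather than a softmax over the batch. This pure-Python sketch assumes that scoring rule (as in the Hugging Face SigLIP model); the logit_scale and logit_bias defaults are placeholders rather than trained values, and `siglip_pair_probability` is a hypothetical helper.

```python
import math
from typing import List, Sequence


def _l2_normalize(v: Sequence[float]) -> List[float]:
    # Assumes a non-zero vector.
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def siglip_pair_probability(
    text_features: Sequence[float],
    image_features: Sequence[float],
    logit_scale: float = math.log(10.0),  # placeholder value
    logit_bias: float = -10.0,  # placeholder value
) -> float:
    """Probability that one text and one image match, using sigmoid (not softmax) scoring."""
    t = _l2_normalize(text_features)
    i = _l2_normalize(image_features)
    cosine = sum(a * b for a, b in zip(t, i))
    logit = cosine * math.exp(logit_scale) + logit_bias
    return 1.0 / (1.0 + math.exp(-logit))
```

Because each pair is scored independently, the probability for one image does not depend on which other texts are in the batch, unlike CLIP-style softmax scoring.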
class easydel.__init__.SiglipTextConfig(vocab_size=32000, hidden_size=768, intermediate_size=3072, num_hidden_layers=12, num_attention_heads=12, max_position_embeddings=64, hidden_act='gelu_pytorch_tanh', layer_norm_eps=1e-06, attention_dropout=0.0, pad_token_id=1, bos_token_id=49406, eos_token_id=49407, projection_size=None, **kwargs)

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 32000) – Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by the inputs_ids passed when calling [SiglipModel].

  • hidden_size (int, optional, defaults to 768) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 3072) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 12) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 12) – Number of attention heads for each attention layer in the Transformer encoder.

  • max_position_embeddings (int, optional, defaults to 64) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048).

  • hidden_act (str or function, optional, defaults to “gelu_pytorch_tanh”) – The non-linear activation function (function or string) in the encoder and pooler. If string, “gelu”, “relu”, “selu”, “gelu_new” and “quick_gelu” are supported.

  • layer_norm_eps (float, optional, defaults to 1e-06) – The epsilon used by the layer normalization layers.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • pad_token_id (int, optional, defaults to 1) – The id of the padding token in the vocabulary.

  • bos_token_id (int, optional, defaults to 49406) – The id of the beginning-of-sequence token in the vocabulary.

  • eos_token_id (int, optional, defaults to 49407) – The id of the end-of-sequence token in the vocabulary.

  • projection_size (int, optional, defaults to hidden_size) – The size of the projection head.

Example:

```python
>>> from transformers import SiglipTextConfig, SiglipTextModel

>>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration
>>> configuration = SiglipTextConfig()
>>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration
>>> model = SiglipTextModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
base_config_key: str = 'text_config'
get_partition_rules(*args, **kwargs)

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'siglip_text_model'
class easydel.__init__.SiglipTextModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.__init__.SiglipVisionConfig(hidden_size=768, intermediate_size=3072, num_hidden_layers=12, num_attention_heads=12, num_channels=3, image_size=224, patch_size=16, hidden_act='gelu_pytorch_tanh', layer_norm_eps=1e-06, attention_dropout=0.0, **kwargs)

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • hidden_size (int, optional, defaults to 768) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 3072) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 12) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 12) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_channels (int, optional, defaults to 3) – Number of channels in the input images.

  • image_size (int, optional, defaults to 224) – The size (resolution) of each image.

  • patch_size (int, optional, defaults to 16) – The size (resolution) of each patch.

  • hidden_act (str or function, optional, defaults to “gelu_pytorch_tanh”) – The non-linear activation function (function or string) in the encoder and pooler. If string, “gelu”, “relu”, “selu”, “gelu_new” and “quick_gelu” are supported.

  • layer_norm_eps (float, optional, defaults to 1e-06) – The epsilon used by the layer normalization layers.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
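image_size and patch_size together fix the length of the vision encoder's patch sequence: a square image is cut into non-overlapping patch_size x patch_size tiles, each of which becomes one token. A small worked sketch with hypothetical helper names, assuming square inputs whose side is divisible by the patch size:

```python
def num_patches(image_size: int = 224, patch_size: int = 16) -> int:
    """Number of patch tokens produced from one square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size  # 224 // 16 = 14
    return per_side * per_side  # 14 * 14 = 196


def patch_input_dim(patch_size: int = 16, num_channels: int = 3) -> int:
    """Raw values in one flattened patch (16 * 16 * 3 = 768) before the patch projection."""
    return patch_size * patch_size * num_channels
```

With the defaults, each image becomes 196 tokens; raising image_size to 384 at the same patch size gives 576 tokens, which grows attention cost roughly quadratically.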


base_config_key: str = 'vision_config'
get_partition_rules(*args, **kwargs)

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'siglip_vision_model'
class easydel.__init__.SiglipVisionModel(*args: Any, **kwargs: Any)

Bases: Module

class easydel.__init__.StableLmConfig(vocab_size=50304, intermediate_size=6912, hidden_size=2560, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=32, hidden_act='silu', max_position_embeddings=4096, initializer_range=0.02, layer_norm_eps=1e-05, use_cache=True, tie_word_embeddings=False, rope_theta=10000, rope_scaling=None, use_qkv_bias=False, qk_layernorm=False, use_parallel_residual=False, hidden_dropout=0.0, attention_dropout=0.0, partial_rotary_factor=0.25, bos_token_id=0, eos_token_id=0, bits: Optional[int] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 50304) – Vocabulary size of the StableLM model. Defines the number of different tokens that can be represented by the inputs_ids passed when calling [~easydel.modules.StableLmModel].

  • hidden_size (int, optional, defaults to 2560) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 6912) – Dimensionality of the “intermediate” (often named feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 32) – Number of key-value heads for each attention layer in the Transformer encoder.

  • hidden_act (str or tp.Callable, optional, defaults to “silu”) – The non-linear activation function (function or string) in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 4096) – The maximum sequence length that this model might ever be used with.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • layer_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the layer normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models).

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (int, optional, defaults to 10000) – The theta value for the rotary position embeddings.

  • rope_scaling (str, optional) – The scaling to use for the rotary position embeddings.

  • qk_layernorm (bool, optional, defaults to False) – Whether to use layer normalization on the queries and keys in the attention layer.

  • use_parallel_residual (bool, optional, defaults to False) – Whether to use a parallel residual connection in the attention layer.

  • hidden_dropout (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • partial_rotary_factor (float, optional, defaults to 0.25) – Fraction of each attention head's query/key dimensions to which rotary position embeddings are applied.

  • bos_token_id (int, optional, defaults to 0) – The id for the beginning of stream token.

  • eos_token_id (int, optional, defaults to 0) – The id for the end of stream token.

  • bits (int, optional) – The number of bits to quantize the model to. If None, the model is not quantized.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – What to save during gradient checkpointing. Choose one of “nothing_saveable”, “first_half_saveable”, “full_saveable”.
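The attention-related defaults above combine as follows, assuming EasyDeL follows the Hugging Face StableLM convention of applying rotary embeddings to the first int(partial_rotary_factor * head_dim) dimensions of each head and passing the rest through unchanged. A worked sketch with a hypothetical helper name:

```python
def attention_dims(
    hidden_size: int = 2560,
    num_attention_heads: int = 32,
    num_key_value_heads: int = 32,
    partial_rotary_factor: float = 0.25,
) -> dict:
    head_dim = hidden_size // num_attention_heads  # 2560 // 32 = 80
    rotary_ndims = int(partial_rotary_factor * head_dim)  # int(0.25 * 80) = 20
    return {
        "head_dim": head_dim,
        "rotary_ndims": rotary_ndims,  # dimensions that receive RoPE
        "pass_through_ndims": head_dim - rotary_ndims,  # dimensions left unrotated
        # Query heads sharing each key/value head; 1 means standard multi-head attention.
        "queries_per_kv_head": num_attention_heads // num_key_value_heads,
    }
```

With the defaults, only 20 of each head's 80 dimensions are rotated; setting num_key_value_heads=8 would make four query heads share each key/value head.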

get_partition_rules(fully_sharded_data_parallel: bool = True)

Get the partition rules for the model.

Parameters

fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

property granted_freq_max_position_embedding: int
property granted_mask_max_position_embedding: int
model_type: str = 'stablelm'
class easydel.__init__.StableLmForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.__init__.StableLmModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

property frequencies

Returns frequency values from the config.

class easydel.__init__.TaskType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: str, Enum

AUDIO_CLASSIFICATION = 'audio-classification'
BASE_MODULE = 'base-module'
BASE_VISION = 'vision-module'
CAUSAL_LM = 'causal-language-model'
IMAGE_CLASSIFICATION = 'image-classification'
IMAGE_TEXT_TO_TEXT = 'image-text-to-text'
SEQUENCE_CLASSIFICATION = 'sequence-classification'
SEQUENCE_TO_SEQUENCE = 'sequence-to-sequence'
SPEECH_SEQUENCE_TO_SEQUENCE = 'speech-sequence-to-sequence'
VISION_LM = 'vision-language-model'
ZERO_SHOT_IMAGE_CLASSIFICATION = 'zero-shot-image-classification'
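Because TaskType derives from both str and Enum, its members compare equal to their string values and can be recovered from a plain string (for example one read from a config file) by value lookup. The standalone sketch below reproduces a subset of the members listed above instead of importing easydel, and `parse_task` is a hypothetical helper:

```python
from enum import Enum


class TaskType(str, Enum):
    # Subset of the documented members, copied so the sketch runs on its own.
    CAUSAL_LM = "causal-language-model"
    SEQUENCE_CLASSIFICATION = "sequence-classification"
    IMAGE_TEXT_TO_TEXT = "image-text-to-text"


def parse_task(value: str) -> TaskType:
    """Map a string to its TaskType member; raises ValueError for unknown values."""
    return TaskType(value)
```

Since members are str instances, TaskType.CAUSAL_LM == "causal-language-model" is True, so code that expects plain strings keeps working while typos are still caught by parse_task.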
class easydel.__init__.Trainer(arguments: tp.Optional[TrainingArguments] = None, model_state: tp.Optional[EasyDeLState] = None, model: tp.type[EasyDeLBaseModule] = None, dataset_train: tp.Optional[Dataset] = None, dataset_eval: tp.Optional[Dataset] = None, data_collator: tp.Optional[tp.Callable] = None, finetune: bool = True, checkpoint_path: tp.Optional[tp.Union[str, os.PathLike]] = None, **deprecated_kwargs)

Bases: BaseTrainer

configure_functions() → TrainerConfigureFunctionOutput

Configures and JIT-compiles the training and evaluation step functions.

This method prepares the functions that will be used during training and evaluation. It sets up sharding for the model parameters and optimizer state, JIT-compiles the training and evaluation functions with the appropriate static arguments and sharding constraints, and also sets up the checkpoint manager.

Returns

An object containing:
  • sharded_training_step_function: The compiled training step function.

  • sharded_evaluation_step_function: The compiled evaluation step function.

  • mesh: The device mesh used for computation.

  • checkpoint_manager: The checkpointer for saving/loading model state.

Return type

TrainerConfigureFunctionOutput

create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') Callable[source]#

Creates a collate/collect function to process batches of data for training or evaluation.

This function returns a callable that takes a batch (a list of dictionaries) and converts it into a dictionary of JAX arrays. For models of class “ForCausalLMLoss”, it also performs truncation (either keeping the end or the start of the sequence) so that each sequence does not exceed the specified maximum length.

Parameters
  • max_sequence_length (int) – The maximum allowed sequence length.

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – Determines whether to keep the end or the start of the sequence when truncating. Defaults to “keep_end”.

Returns

A function that takes a batch (list of dicts) and returns a processed dict of arrays.

Return type

tp.Callable
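The two truncation modes can be sketched in plain Python as follows. This is not EasyDeL's implementation. The real function returns JAX arrays and, per its description, truncates only for causal-LM models. This sketch truncates every field of every example and returns lists. The names `truncate` and `make_collate` are hypothetical.

```python
from typing import Callable, Dict, List, Literal

Batch = List[Dict[str, List[int]]]
Mode = Literal["keep_end", "keep_start"]


def truncate(ids: List[int], max_len: int, mode: Mode) -> List[int]:
    """Trim ids to at most max_len tokens, keeping either the tail or the head."""
    if len(ids) <= max_len:
        return list(ids)
    if mode == "keep_end":
        return list(ids[len(ids) - max_len:])  # drop tokens from the start
    if mode == "keep_start":
        return list(ids[:max_len])  # drop tokens from the end
    raise ValueError(f"unknown truncation mode: {mode!r}")


def make_collate(
    max_sequence_length: int, truncation_mode: Mode = "keep_end"
) -> Callable[[Batch], Dict[str, List[List[int]]]]:
    """Return a collate function that truncates every field of every example."""

    def collate(batch: Batch) -> Dict[str, List[List[int]]]:
        return {
            key: [truncate(example[key], max_sequence_length, truncation_mode) for example in batch]
            for key in batch[0]
        }

    return collate


batch = [{"input_ids": [1, 2, 3, 4, 5, 6], "attention_mask": [1, 1, 1, 1, 1, 1]}]
print(make_collate(4, "keep_end")(batch)["input_ids"])    # [[3, 4, 5, 6]]
print(make_collate(4, "keep_start")(batch)["input_ids"])  # [[1, 2, 3, 4]]
```

"keep_end" keeps the most recent tokens, which is usually what you want for causal language modeling on long documents.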

eval(model_state: EasyDeLState) Iterator[dict][source]#

Evaluates the model using the provided model state.

This method iterates over the evaluation dataset, performs forward passes, calculates evaluation metrics, logs the metrics, and yields the metrics for each evaluation step.

Parameters

model_state (EasyDeLState) – The state of the model (including parameters and configuration) to be used for evaluation.

Yields

Iterator[dict] – An iterator yielding a dictionary of evaluation metrics for each evaluation step.

Raises

AssertionError – If the evaluation dataloader is not set.

train() TrainerOutput[source]#

Executes the complete training process.

This method sets up initial metrics and logging, runs the training loop, and finalizes training. It calls the training hook at the beginning and returns a TrainerOutput object at the end.

Returns

An object containing the final training state, metrics, and any additional outputs.

Return type

TrainerOutput

class easydel.__init__.TrainingArguments(auto_shard_states: 'bool' = True, aux_loss_enabled: 'bool' = False, backend: 'tp.Optional[str]' = None, clip_grad: 'tp.Optional[float]' = None, custom_scheduler: 'tp.Optional[tp.Callable[[int], tp.Any]]' = None, dataloader_num_workers: 'tp.Optional[int]' = 0, dataloader_pin_memory: 'tp.Optional[bool]' = False, do_eval: 'bool' = False, do_last_save: 'bool' = True, do_train: 'bool' = True, eval_batch_size: 'tp.Optional[int]' = None, evaluation_steps: 'tp.Optional[int]' = None, extra_optimizer_kwargs: 'dict' = <factory>, frozen_parameters: 'tp.Optional[str]' = None, gradient_accumulation_steps: 'int' = 1, ids_to_pop_from_dataset: 'tp.Optional[tp.List[str]]' = <factory>, is_fine_tuning: 'bool' = True, init_tx: 'bool' = True, jax_distributed_config: 'tp.Optional[dict]' = None, learning_rate: 'float' = 5e-05, learning_rate_end: 'tp.Optional[float]' = None, log_all_workers: 'bool' = False, log_grad_norms: 'bool' = True, report_metrics: 'bool' = True, log_steps: 'int' = 10, loss_config: 'tp.Optional[LossConfig]' = None, low_mem_usage: 'bool' = True, max_evaluation_steps: 'tp.Optional[int]' = None, max_sequence_length: 'tp.Optional[int]' = 4096, max_training_steps: 'tp.Optional[int]' = None, model_name: 'str' = 'BaseTrainer', model_parameters: 'tp.Optional[dict]' = None, metrics_to_show_in_rich_pbar: 'tp.Optional[tp.List[str]]' = None, num_train_epochs: 'int' = 10, offload_dataset: 'bool' = False, offload_device_type: 'str' = 'cpu', offload_device_index: 'int' = 0, optimizer: 'AVAILABLE_OPTIMIZERS' = <EasyDeLOptimizers.ADAMW: 'adamw'>, performance_mode: 'bool' = False, pruning_module: 'tp.Any' = None, process_zero_is_admin: 'bool' = True, progress_bar_type: "tp.Literal['tqdm', 'rich', 'json']" = 'tqdm', remove_ckpt_after_load: 'bool' = False, remove_unused_columns: 'bool' = True, report_steps: 'int' = 5, save_directory: 'str' = 'EasyDeL-Checkpoints', save_optimizer_state: 'bool' = True, save_steps: 'tp.Optional[int]' = None, save_total_limit: 'tp.Optional[int]' = None, scheduler: 'AVAILABLE_SCHEDULERS' = <EasyDeLSchedulers.NONE: 'None'>, sparsify_module: 'bool' = False, sparse_module_type: 'AVAILABLE_SPARSE_MODULE_TYPES' = 'bcoo', state_apply_fn_kwarguments_to_model: 'tp.Optional[dict]' = None, step_partition_spec: 'PartitionSpec' = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: 'tp.Optional[int]' = None, shuffle_train_dataset: 'bool' = True, total_batch_size: 'int' = 32, training_time_limit: 'tp.Optional[str]' = None, train_on_inputs: 'bool' = True, truncation_mode: "tp.Literal['keep_end', 'keep_start']" = 'keep_end', tx_mu_dtype: 'tp.Optional[jnp.dtype]' = None, track_memory: 'bool' = False, use_data_collactor: 'bool' = True, use_wandb: 'bool' = True, verbose: 'bool' = True, wandb_entity: 'tp.Optional[str]' = None, warmup_steps: 'int' = 0, weight_decay: 'float' = 0.01, weight_distribution_pattern: 'str' = '.*?(layernorm|norm).*?', weight_distribution_log_steps: 'int' = 0)[source]#

Bases: object

auto_shard_states: bool = True#
aux_loss_enabled: bool = False#
backend: Optional[str] = None#
clip_grad: Optional[float] = None#
custom_scheduler: Optional[Callable[[int], Any]] = None#
dataloader_num_workers: Optional[int] = 0#
dataloader_pin_memory: Optional[bool] = False#
do_eval: bool = False#
do_last_save: bool = True#
do_train: bool = True#
ensure_checkpoint_path()[source]#

Creates the checkpoint directory if it doesn’t exist.

ensure_training_time_limit(time_passed)[source]#
eval_batch_size: Optional[int] = None#
evaluation_steps: Optional[int] = None#
extra_optimizer_kwargs: dict#
classmethod from_dict(config: Dict[str, Any]) TrainingArguments[source]#

Creates a TrainingArguments instance from a dictionary.

Parameters

config (tp.Dict[str, tp.Any]) – The configuration dictionary.

Returns

A TrainingArguments object initialized with values from the dictionary.

Return type

TrainingArguments

frozen_parameters: Optional[str] = None#
get_optimizer_and_scheduler(steps: Optional[int] = None)[source]#

Returns the configured optimizer and learning rate scheduler.

Parameters

steps (tp.Optional[int]) – The number of training steps. If not provided, uses the value from self.optimizer_kwargs.

Returns

A tuple containing the optimizer and scheduler.

Return type

tuple

get_path() Path[source]#

Returns the path to the checkpoint directory.

Returns

The path to the checkpoint directory.

Return type

Path

get_streaming_checkpointer()[source]#

Returns the checkpoint manager, responsible for saving model checkpoints.

Returns

The checkpoint manager.

Return type

CheckpointManager

get_tensorboard()[source]#

Returns the TensorBoard SummaryWriter, used for logging metrics.

Returns

The TensorBoard SummaryWriter.

Return type

flax.metrics.tensorboard.SummaryWriter

get_wandb_init()[source]#

Initializes Weights & Biases for experiment tracking if enabled.

Returns

The WandB run object if initialized, else None.

Return type

tp.Optional[wandb.sdk.wandb_run.Run]

gradient_accumulation_steps: int = 1#
ids_to_pop_from_dataset: Optional[List[str]]#
init_tx: bool = True#
is_fine_tuning: bool = True#
property is_process_zero#
jax_distributed_config: Optional[dict] = None#
learning_rate: float = 5e-05#
learning_rate_end: Optional[float] = None#
log_all_workers: bool = False#
log_grad_norms: bool = True#
log_metrics(metrics: Any, step: int, log_as: Optional[Literal['summary', 'config']] = None)[source]#

Logs training metrics to Weights & Biases and/or TensorBoard.

Parameters
  • metrics (tp.Dict[str, tp.Union[float, tp.List, tp.Tuple, np.ndarray, 'jnp.ndarray', 'torch.Tensor']]) – A dictionary where keys are metric names and values are metric values.

  • step (int) – The current training step or iteration.

log_steps: int = 10#
log_weight_distribution(state, step: int)[source]#
loss_config: Optional[LossConfig] = None#
low_mem_usage: bool = True#
max_evaluation_steps: Optional[int] = None#
max_sequence_length: Optional[int] = 4096#
max_training_steps: Optional[int] = None#
metrics_to_show_in_rich_pbar: Optional[List[str]] = None#
model_name: str = 'BaseTrainer'#
model_parameters: Optional[dict] = None#
num_train_epochs: int = 10#
offload_dataset: bool = False#
property offload_device#
offload_device_index: int = 0#
offload_device_type: str = 'cpu'#
optimizer: Literal['adafactor', 'lion', 'adamw', 'rmsprop'] = 'adamw'#
performance_mode: bool = False#
process_zero_is_admin: bool = True#
progress_bar_type: Literal['tqdm', 'rich', 'json'] = 'tqdm'#
pruning_module: Any = None#
remove_ckpt_after_load: bool = False#
remove_unused_columns: bool = True#
replace(**kwargs)#
report_metrics: bool = True#
report_steps: int = 5#
save_directory: str = 'EasyDeL-Checkpoints'#
save_optimizer_state: bool = True#
save_steps: Optional[int] = None#
save_total_limit: Optional[int] = None#
scheduler: Literal['linear', 'cosine', 'none'] = 'None'#
shuffle_train_dataset: bool = True#
sparse_module_type: Literal['bcoo', 'bcsr', 'coo', 'csr'] = 'bcoo'#
sparsify_module: bool = False#
state_apply_fn_kwarguments_to_model: Optional[dict] = None#
step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp')#
step_start_point: Optional[int] = None#
to_dict() Dict[str, Any][source]#

Converts the TrainingArguments object into a dictionary.

Returns

A dictionary representation of the TrainingArguments.

Return type

tp.Dict[str, tp.Any]

total_batch_size: int = 32#
track_memory: bool = False#
train_on_inputs: bool = True#
training_time_limit: Optional[str] = None#
property training_time_seconds: int#
truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end'#
tx_mu_dtype: Optional[dtype] = None#
use_data_collactor: bool = True#
use_wandb: bool = True#
verbose: bool = True#
wandb_entity: Optional[str] = None#
warmup_steps: int = 0#
weight_decay: float = 0.01#
weight_distribution_log_steps: int = 0#
weight_distribution_pattern: str = '.*?(layernorm|norm).*?'#
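The default weight_distribution_pattern is a regular expression. Judging from the field names, it presumably selects the parameters whose weight distributions are logged every weight_distribution_log_steps steps; that reading is our assumption. The sketch below shows which hypothetical parameter paths the default pattern matches. Because the pattern starts with `.*?`, `re.search` and `re.match` select the same names. Since "norm" is one of the alternatives, the pattern matches any name that contains "norm".

```python
import re

# Default value of TrainingArguments.weight_distribution_pattern.
PATTERN = ".*?(layernorm|norm).*?"

# Hypothetical parameter paths, for illustration only.
param_names = [
    "model/embed_tokens/embedding",
    "model/layers/0/input_layernorm/kernel",
    "model/layers/0/self_attn/q_proj/kernel",
    "model/norm/kernel",
]

# Keep only the names the pattern matches.
selected = [name for name in param_names if re.search(PATTERN, name)]
print(selected)  # ['model/layers/0/input_layernorm/kernel', 'model/norm/kernel']
```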
class easydel.__init__.WhisperConfig(vocab_size=51865, num_mel_bins=80, encoder_layers=4, encoder_attention_heads=6, decoder_layers=4, decoder_attention_heads=6, decoder_ffn_dim=1536, encoder_ffn_dim=1536, encoder_layerdrop=0.0, decoder_layerdrop=0.0, decoder_start_token_id=50257, use_cache=True, is_encoder_decoder=True, activation_function='gelu', d_model=384, dropout=0.0, attention_dropout=0.0, activation_dropout=0.0, init_std=0.02, scale_embedding=False, max_source_positions=1500, max_target_positions=448, pad_token_id=50256, bos_token_id=50256, eos_token_id=50256, suppress_tokens=None, begin_suppress_tokens=[220, 50256], use_weighted_layer_sum=False, classifier_proj_size=256, apply_spec_augment=False, mask_time_prob=0.05, mask_time_length=10, mask_time_min_masks=2, mask_feature_prob=0.0, mask_feature_length=10, mask_feature_min_masks=0, median_filter_width=7, bits: Optional[int] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 51865) – Vocabulary size of the Whisper model. Defines the number of different tokens that can be represented by the inputs_ids passed when calling [~easydel.modules.WhisperModel].

  • num_mel_bins (int, optional, defaults to 80) – Number of mel bins used by the feature extractor.

  • encoder_layers (int, optional, defaults to 4) – Number of encoder layers.

  • encoder_attention_heads (int, optional, defaults to 6) – Number of attention heads for each attention layer in the Transformer encoder.

  • decoder_layers (int, optional, defaults to 4) – Number of decoder layers.

  • decoder_attention_heads (int, optional, defaults to 6) – Number of attention heads for each attention layer in the Transformer decoder.

  • decoder_ffn_dim (int, optional, defaults to 1536) – Dimensionality of the decoder feed-forward network (FFN) layer.

  • encoder_ffn_dim (int, optional, defaults to 1536) – Dimensionality of the encoder feed-forward network (FFN) layer.

  • encoder_layerdrop (float, optional, defaults to 0.0) – The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) for more details.

  • decoder_layerdrop (float, optional, defaults to 0.0) – The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) for more details.

  • d_model (int, optional, defaults to 384) – Dimensionality of the layers and the pooler layer.

  • activation_function (str, optional, defaults to “gelu”) – The non-linear activation function (function or string) in the encoder and pooler. If string, “gelu”, “relu”, “silu” and “gelu_new” are supported.

  • dropout (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • activation_dropout (float, optional, defaults to 0.0) – The dropout ratio for activations inside the fully connected layer.

  • init_std (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • scale_embedding (bool, optional, defaults to False) – Scale embeddings by dividing by sqrt(d_model).

  • max_source_positions (int, optional, defaults to 1500) – The maximum sequence length allowed for the encoder (audio feature) input to the model. Any longer inputs will be truncated.

  • max_target_positions (int, optional, defaults to 448) – The maximum sequence length allowed for the target text input to the model. Any longer inputs will be truncated.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models).

  • apply_spec_augment (bool, optional, defaults to False) – Whether to apply SpecAugment data augmentation.

  • mask_time_prob (float, optional, defaults to 0.05) – Probability of each feature vector along the time axis to be chosen as the start of the vector span to be masked. Approximately mask_time_prob * sequence_length // mask_time_length spans will be masked along the time axis. This is only relevant if apply_spec_augment is set to True.

  • mask_time_length (int, optional, defaults to 10) – Length of vector span along the time axis.

  • mask_time_min_masks (int, optional, defaults to 2) – The minimum number of masks of length mask_time_length generated along the time axis, regardless of mask_time_prob.

  • mask_feature_prob (float, optional, defaults to 0.0) – Probability of each feature vector along the feature axis to be chosen as the start of the vector span to be masked. Approximately mask_feature_prob * (feature-axis length) // mask_feature_length spans will be masked along the feature axis. This is only relevant if apply_spec_augment is set to True.

  • mask_feature_length (int, optional, defaults to 10) – Length of vector span along the feature axis.

  • mask_feature_min_masks (int, optional, defaults to 0) – The minimum number of masks of length mask_feature_length generated along the feature axis, regardless of mask_feature_prob.

  • median_filter_width (int, optional, defaults to 7) – Width of the median filter used to smooth the cross-attention outputs when computing token timestamps. Should be an odd number.

  • bits (int, optional) – The number of bits to quantize the model to. If None, the model is not quantized.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing policy, i.e., which intermediate values to save during gradient checkpointing. See EasyDeLGradientCheckPointers for the available policies.
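The SpecAugment parameters above count masked spans roughly as probability × axis length ÷ span length, with a lower bound from the corresponding *_min_masks value. The sketch below shows that arithmetic. It is simplified: real implementations may randomize the rounding, and the cap at "spans that fit" is our own addition. The 3000-frame input corresponds to Whisper's 30-second log-mel window.

```python
def num_masked_spans(seq_len: int, mask_prob: float, mask_length: int, min_masks: int) -> int:
    """Approximate number of masked spans along one axis (deterministic rounding)."""
    n = int(mask_prob * seq_len / mask_length)
    n = max(n, min_masks)
    # Never request more spans than can fit into the axis.
    return min(n, seq_len // mask_length)


# Time axis with the default mask_time_* values and 3000 log-mel frames.
print(num_masked_spans(3000, mask_prob=0.05, mask_length=10, min_masks=2))  # 15
# Short input: the probability term rounds to 0, so the minimum of 2 applies.
print(num_masked_spans(100, mask_prob=0.05, mask_length=10, min_masks=2))   # 2
# Feature axis (80 mel bins) with the default mask_feature_* values: no masking.
print(num_masked_spans(80, mask_prob=0.0, mask_length=10, min_masks=0))     # 0
```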

attribute_map: Dict[str, str] = {'hidden_size': 'd_model', 'num_attention_heads': 'encoder_attention_heads'}#
get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'whisper'#
class easydel.__init__.WhisperForAudioClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.WhisperForConditionalGeneration(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

compute_loss(*, labels: Optional[Union[Array, ndarray, bool, number]] = None, loss_config: Optional[LossConfig] = None, loss_kwargs: Optional[Dict] = None, **batch) Tuple[Any, LossMetrics][source]#

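Basic compute_loss call: computes the loss for the given batch (passed as keyword arguments), optionally using explicit labels, a LossConfig, and extra loss_kwargs. Returns a tuple of the model outputs and the LossMetrics.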

decode(decoder_input_ids, encoder_outputs, encoder_attention_mask: Optional[Array] = None, decoder_attention_mask: Optional[Array] = None, decoder_position_ids: Optional[Array] = None, past_key_values: Optional[dict] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None)[source]#
encode(input_features: Array, attention_mask: Optional[Array] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs)[source]#
generate(input_features, generation_config=None, logits_processor=None, return_timestamps=None, task=None, language=None, is_multilingual=None, **kwargs)[source]#

Generates sequences of token ids for models with a language modeling head.

Parameters
  • input_features (chex.Array of shape (batch_size, num_mel_bins, sequence_length)) – The log-mel features of the audio input.

  • generation_config (~generation.GenerationConfig, optional) – The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them. If generation_config is not provided, the default will be used, which has the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration. Please note that unspecified parameters will inherit [~generation.GenerationConfig]’s default values, whose documentation should be checked to parameterize generation.

  • trace (bool, optional, defaults to True) – Whether to trace generation. Setting trace=False should only be used for debugging and will lead to a considerably slower runtime.

  • logits_processor (FlaxLogitsProcessorList, optional) – Custom logits processors that complement the default logits processors built from arguments and generation config. If a logits processor is passed that is already created with the arguments or a generation config, an error is thrown. This feature is intended for advanced users.

  • kwargs (tp.Dict[str, Any], optional) – Ad hoc parametrization of generate_config and/or additional model-specific kwargs that will be forwarded to the forward function of the model. If the model is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with decoder_.

Returns

[~utils.ModelOutput].

loss_type = 'ForCausalLM'#
prepare_inputs_for_generation(decoder_input_ids, max_length, attention_mask: Optional[Array] = None, decoder_attention_mask: Optional[Array] = None, encoder_outputs=None, **kwargs)[source]#

Prepares the model inputs for a generation task.

Parameters
  • decoder_input_ids – The decoder input token IDs.

  • max_length – The maximum length of the sequence to be generated.

  • attention_mask (tp.Optional[chex.Array]) – The attention mask for the encoder inputs.

  • decoder_attention_mask (tp.Optional[chex.Array]) – The attention mask for the decoder inputs.

  • encoder_outputs – The precomputed encoder outputs.

Returns

A dictionary of the past_key_values, attention masks, and position IDs.

update_inputs_for_generation(model_outputs, model_kwargs)[source]#
class easydel.__init__.WhisperTimeStampLogitsProcessor(generate_config, model_config, decoder_input_length)[source]#

Bases: FlaxLogitsProcessor

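Whisper-specific logits processor for timestamp prediction. During generation, it constrains the log probabilities of the timestamp tokens. For example, it suppresses the “<|notimestamps|>” token and limits how far in the future the first predicted timestamp may lie (see max_initial_timestamp_index).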

Parameters

generate_config (GenerateConfig) – The generation config used to generate the output. The following attributes are used:
  • eos_token_id (int, optional, defaults to 50257) – The id of the end-of-sequence token.

  • no_timestamps_token_id (int, optional, defaults to 50363) – The id of the “<|notimestamps|>” token.

  • max_initial_timestamp_index (int, optional, defaults to 1) – The maximum value of the initial timestamp. This prevents the model from predicting timestamps that are too far in the future.
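Two of the timestamp constraints can be sketched on a single row of logits. First, the “<|notimestamps|>” token is suppressed. Second, at the first generation step, timestamps beyond max_initial_timestamp_index are masked. This is a simplified sketch modeled on the Hugging Face Transformers Whisper timestamp processor, not EasyDeL's code. It assumes timestamp tokens start right after no_timestamps_token_id, and it omits the remaining rules, such as pairing timestamps. The function name and the toy vocabulary are hypothetical.

```python
import math
from typing import List


def apply_timestamp_rules(
    logits: List[float],
    no_timestamps_token_id: int,
    max_initial_timestamp_index: int,
    is_first_step: bool,
) -> List[float]:
    """Apply two of Whisper's timestamp constraints to one row of logits."""
    out = list(logits)
    # Rule 1: never emit <|notimestamps|> while predicting timestamps.
    out[no_timestamps_token_id] = -math.inf
    if is_first_step:
        # Rule 2: the first timestamp may not exceed the allowed initial index.
        # Assumption: timestamp tokens start right after <|notimestamps|>.
        timestamp_begin = no_timestamps_token_id + 1
        last_allowed = timestamp_begin + max_initial_timestamp_index
        for token_id in range(last_allowed + 1, len(out)):
            out[token_id] = -math.inf
    return out


# Toy vocabulary of 8 tokens: id 3 is <|notimestamps|>, ids 4..7 are timestamps.
print(apply_timestamp_rules([0.0] * 8, no_timestamps_token_id=3,
                            max_initial_timestamp_index=1, is_first_step=True))
# [0.0, 0.0, 0.0, -inf, 0.0, 0.0, -inf, -inf]
```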

class easydel.__init__.Xerxes2Config(vocab_size: int = 256128, hidden_size: int = 4096, intermediate_size: int = 16384, num_hidden_layers: int = 32, num_attention_heads: int = 32, max_position_embeddings: int = 16384, initializer_range: float = 0.02, rms_norm_eps: float = 1e-06, use_cache: bool = True, pad_token_id: int = 0, eos_token_id: int = 1, bos_token_id: int = 2, tie_word_embeddings: bool = False, rope_theta: float = 10000.0, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, scan_layers: bool = False, q_lora_dim: Optional[int] = 1536, kv_lora_dim: int = 512, qk_rope_head_dim: int = 64, qk_nope_head_dim: int = 128, vhead_dim: int = 128, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 256128) – Vocabulary size of the xerxes model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 16384) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 16) – Number of key and value heads for each attention layer in the Transformer encoder.

  • head_dim (int, optional, defaults to 256) – Dimensionality of the attention head.

  • max_position_embeddings (int, optional, defaults to 16384) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional, defaults to 0) – The index of the padding token in the vocabulary.

  • eos_token_id (int, optional, defaults to 1) – The index of the end of sequence token in the vocabulary.

  • bos_token_id (int, optional, defaults to 2) – The index of the beginning of sequence token in the vocabulary.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • softmax_scale (float, optional, defaults to 14.9666295471) – softmax scale for attention module.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • bits (int, optional) – The number of bits to quantize the model to.

  • scan_layers (bool, optional, defaults to False) – Whether to use the scan implementation of the layers.

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

Attaches the following EasyDeL-specific arguments to the configuration.

Parameters
  • gradient_checkpointing (EasyDeLGradientCheckPointers) – The gradient checkpointing policy, which controls how much memory JAX uses during backpropagation.

  • bits (tp.Optional[int]) – The number of bits used for quantization.

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Parameters

fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()[source]#
property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
model_type: str = 'xerxes2'#
static rng_keys()[source]#
class easydel.__init__.Xerxes2ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

init_cache(batch_size: int, max_length: int)[source]#
class easydel.__init__.Xerxes2Model(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

property frequencies: Array#

Returns frequency values from the config.

class easydel.__init__.XerxesConfig(vocab_size=256128, hidden_size=4096, intermediate_size=16384, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=8, head_dim=144, max_position_embeddings=16384, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, pad_token_id=0, eos_token_id=1, bos_token_id=2, num_local_experts: int = 4, xe_moe: bool = True, num_experts_per_tok: int = 2, tie_word_embeddings=False, rope_theta=10000.0, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, scan_layers: bool = False, **kwargs)[source]#

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 256128) – Vocabulary size of the xerxes model. Defines the number of different tokens that can be represented by the inputs_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4096) – Dimensionality of the encoder layers and the pooler layer.

  • intermediate_size (int, optional, defaults to 16384) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer encoder.

  • num_key_value_heads (int, optional, defaults to 8) – Number of key and value heads for each attention layer in the Transformer encoder.

  • head_dim (int, optional, defaults to 144) – Dimensionality of the attention head.

  • max_position_embeddings (int, optional, defaults to 16384) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-6) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional, defaults to 0) – The index of the padding token in the vocabulary.

  • eos_token_id (int, optional, defaults to 1) – The index of the end of sequence token in the vocabulary.

  • bos_token_id (int, optional, defaults to 2) – The index of the beginning of sequence token in the vocabulary.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • softmax_scale (float, optional, defaults to 14.9666295471) – softmax scale for attention module.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • bits (int, optional) – The number of bits to quantize the model to.

  • scan_layers (bool, optional, defaults to False) – Whether to use the scan implementation of the layers.

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, bits: Optional[int] = None, **kwargs)[source]#

Attaches the following EasyDeL-specific arguments to the configuration.

Parameters
  • gradient_checkpointing (EasyDeLGradientCheckPointers) – The gradient checkpointing policy, which controls how much memory JAX uses during backpropagation.

  • bits (tp.Optional[int]) – The number of bits used for quantization.

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Parameters

fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]

static get_weight_decay_exclusions()[source]#
property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
model_type: str = 'xerxes'#
static rng_keys()[source]#
class easydel.__init__.XerxesForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.__init__.XerxesModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

easydel.__init__.get_modules_by_type(model_type: str, task_type: TaskType) Tuple[Type[EasyDeLBaseConfig], Union[Type[EasyDeLBaseModule], Any]][source]#
Returns the classes registered for a model type and task type, as a tuple of:
  1. The config class for the specified model type (e.g., LlamaConfig, FalconConfig).

  2. The EasyDeL module class for the specified model type and task type (e.g., LlamaForCausalLM, FalconForCausalLM).

easydel.__init__.module_to_huggingface_model(module: ~typing.Any, config: ~typing.Any, base_huggingface_module: ~typing.Any, base_huggingface_module_kwarguments: ~typing.Optional[~typing.Dict] = None, dtype: ~numpy.dtype = <class 'jax.numpy.float16'>, use_meta_torch: bool = True, **kw)[source]#
easydel.__init__.module_to_torch(module: ~typing.Any, dtype: ~numpy.dtype = <class 'jax.numpy.float16'>)[source]#
easydel.__init__.pack_sequences(dataset: Any, max_length: int = 512, pad_token_id: int = 0, reset_position_ids: bool = False, num_proc: Optional[int] = None)[source]#

Pack sequences together with their attention masks and position IDs.

Example:

    # With continuous position IDs
    packed_dataset = pack_sequences(
        dataset, max_length=512, pad_token_id=0, reset_position_ids=False
    )

    # With position IDs reset for each sequence
    packed_dataset = pack_sequences(
        dataset, max_length=512, pad_token_id=0, reset_position_ids=True
    )

Example output format for a packed sequence containing two sequences:

    # reset_position_ids=False:
    {
        'input_ids': [seq1_tokens + [PAD] + seq2_tokens + [PAD] + padding],
        'attention_mask': [1, 1, 1, 0, 1, 1, 1, 0, 0, 0],
        'position_ids': [0, 1, 2, 3, 4, 5, 6, 7, 0, 0],
    }

    # reset_position_ids=True:
    {
        'input_ids': [seq1_tokens + [PAD] + seq2_tokens + [PAD] + padding],
        'attention_mask': [1, 1, 1, 0, 1, 1, 1, 0, 0, 0],
        'position_ids': [0, 1, 2, 0, 0, 1, 2, 0, 0, 0],
    }

Parameters
  • dataset – Dataset containing ‘input_ids’ and ‘attention_mask’

  • max_length – Maximum length of packed sequence

  • pad_token_id – Token ID used for padding

  • reset_position_ids – If True, reset position IDs for each sequence in the pack

Returns

Dataset with packed sequences, attention masks, and position IDs

Return type

packed_dataset
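The sketch below reproduces the packing layout from the example output above on plain Python lists of token IDs, without the datasets library. It assumes greedy packing in input order, one PAD separator after every sequence, and truncation of sequences longer than max_length - 1; pack_token_lists is a hypothetical helper, not EasyDeL's implementation.

```python
from typing import Dict, List, Sequence


def pack_token_lists(
    sequences: Sequence[Sequence[int]],
    max_length: int = 512,
    pad_token_id: int = 0,
    reset_position_ids: bool = False,
) -> List[Dict[str, List[int]]]:
    """Greedily pack token lists into rows of exactly max_length tokens."""
    rows: List[Dict[str, List[int]]] = []
    ids: List[int] = []
    mask: List[int] = []
    pos: List[int] = []

    def close_row() -> None:
        pad = max_length - len(ids)  # trailing padding: masked out, position 0
        rows.append({
            "input_ids": ids + [pad_token_id] * pad,
            "attention_mask": mask + [0] * pad,
            "position_ids": pos + [0] * pad,
        })

    for seq in sequences:
        chunk = list(seq)[: max_length - 1]  # keep room for the PAD separator
        if ids and len(ids) + len(chunk) + 1 > max_length:
            close_row()
            ids, mask, pos = [], [], []
        offset = len(ids)
        ids += chunk + [pad_token_id]   # sequence followed by a PAD separator
        mask += [1] * len(chunk) + [0]  # the separator is not attended to
        if reset_position_ids:
            pos += list(range(len(chunk))) + [0]
        else:
            pos += list(range(offset, offset + len(chunk) + 1))
    if ids:
        close_row()
    return rows


print(pack_token_lists([[11, 12, 13], [21, 22, 23]], max_length=10))
```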

easydel.__init__.register_config(config_type: str, config_field: ConfigType = ConfigType.MODULE_CONFIG) → callable

Register a configuration class.

Parameters
  • config_type – Identifier for the configuration

  • config_field – Type of configuration registry

Returns

Decorator function

easydel.__init__.register_module(task_type: TaskType, config: EasyDeLBaseConfig, model_type: str, embedding_layer_names: Optional[List[str]] = None, layernorm_names: Optional[List[str]] = None) → callable

Register a module for a specific task.

Parameters
  • task_type – Type of task

  • config – Configuration for the module

  • model_type – Identifier for the model

  • embedding_layer_names – Names of embedding layers

  • layernorm_names – Names of layer normalization layers

Returns

Decorator function
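register_config and register_module return decorators, and get_modules_by_type presumably looks up what register_module recorded. A minimal sketch of that registry pattern follows. The dictionaries, ToyTaskType, ToyConfig and ToyForCausalLM are hypothetical; the three functions are standalone re-implementations that share the documented names but not EasyDeL's internals, and the extra arguments (config_field, embedding_layer_names, layernorm_names) are omitted.

```python
from enum import Enum
from typing import Callable, Dict, Tuple, TypeVar

C = TypeVar("C", bound=type)


class ToyTaskType(str, Enum):
    # Hypothetical stand-in for EasyDeL's TaskType enum.
    CAUSAL_LM = "causal-language-model"


CONFIG_REGISTRY: Dict[str, type] = {}
MODULE_REGISTRY: Dict[Tuple[str, ToyTaskType], Tuple[type, type]] = {}


def register_config(config_type: str) -> Callable[[C], C]:
    def decorator(cls: C) -> C:
        CONFIG_REGISTRY[config_type] = cls
        return cls  # returned unchanged so the decorated class stays usable
    return decorator


def register_module(task_type: ToyTaskType, config: type, model_type: str) -> Callable[[C], C]:
    def decorator(cls: C) -> C:
        MODULE_REGISTRY[(model_type, task_type)] = (config, cls)
        return cls
    return decorator


def get_modules_by_type(model_type: str, task_type: ToyTaskType) -> Tuple[type, type]:
    try:
        return MODULE_REGISTRY[(model_type, task_type)]
    except KeyError:
        raise KeyError(
            f"nothing registered for model_type={model_type!r}, task_type={task_type!r}"
        ) from None


@register_config("toy")
class ToyConfig:
    pass


@register_module(ToyTaskType.CAUSAL_LM, config=ToyConfig, model_type="toy")
class ToyForCausalLM:
    pass


config_cls, model_cls = get_modules_by_type("toy", ToyTaskType.CAUSAL_LM)
```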

easydel.__init__.torch_dict_to_easydel_params(state_dict: Dict[str, Any], *, device: Optional[jaxlib.xla_extension.Device] = None, embedding_layer_names: Optional[List[str]] = None, layernorm_names: Optional[List[str]] = None, shard_fns: Optional[Mapping[tuple, Callable]] = None, dtype: numpy.dtype = jax.numpy.float16, verbose: bool = True, callback: Optional[Callable[[jax.Array, tuple], jax.Array]] = None, remove_state_dict: bool = False, lm_head_name: Optional[str] = None, uses_tie_word_embedding: bool = False, **kwargs) → Dict[str, Any]

Convert a PyTorch state dict to the EasyDeL parameter format.

Parameters
  • state_dict – PyTorch state dictionary

  • device – JAX device to use

  • embedding_layer_names – Names of embedding layers

  • layernorm_names – Names of layer normalization layers

  • shard_fns – Mapping of parameter-name tuples to sharding functions

  • block_size – Size of processing blocks

  • params_pattern_selection – Regex pattern for parameter selection

  • dtype – Target dtype for parameters

  • verbose – Whether to show a progress bar

  • callback – Callback applied to each tensor, together with its key tuple, after it is converted to a JAX array; it returns the array to keep

  • remove_state_dict – Whether to delete state_dict after conversion

  • lm_head_name – Name of language model head

  • uses_tie_word_embedding – Whether model uses tied embeddings

  • **kwargs – Additional arguments

Returns

Dictionary of converted parameters in EasyDeL format
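EasyDeL's exact renaming rules are not spelled out here. The sketch below shows the conventions commonly used when moving PyTorch weights into Flax-style parameter trees (as in Hugging Face's PyTorch-to-Flax conversion): dotted keys become nested dicts, 2-D linear weights are transposed and renamed to kernel, embedding weights become embedding, and normalization weights become scale. The function name and matching rules are assumptions; device placement, sharding, callbacks and tied embeddings are left out.

```python
from typing import Any, Dict, Iterable, Mapping

import numpy as np


def torch_style_to_flax_params(
    state_dict: Mapping[str, Any],
    embedding_layer_names: Iterable[str] = (),
    layernorm_names: Iterable[str] = (),
    dtype=np.float16,
) -> Dict[str, Any]:
    embedding_layer_names = set(embedding_layer_names)
    layernorm_names = set(layernorm_names)
    params: Dict[str, Any] = {}
    for key, tensor in state_dict.items():
        *path, leaf = key.split(".")
        # A real torch.Tensor may need .detach().cpu().numpy() before this step.
        array = np.asarray(tensor, dtype=dtype)
        if leaf == "weight":
            if embedding_layer_names.intersection(path):
                leaf = "embedding"
            elif layernorm_names.intersection(path):
                leaf = "scale"
            elif array.ndim == 2:
                # torch Linear stores (out, in); Flax Dense expects (in, out).
                leaf, array = "kernel", array.T
        node = params
        for part in path:
            node = node.setdefault(part, {})
        node[leaf] = array
    return params
```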

class easydel.__init__.vInference(model: None, processor_class: None, generation_config: Optional[vInferenceConfig] = None, seed: Optional[int] = None, input_partition_spec: Optional[PartitionSpec] = None, max_new_tokens: int = 512, inference_name: Optional[str] = None)

Bases: object

Class for performing text generation using a pre-trained language model in EasyDeL.

This class handles the generation process, including initialization, precompilation, and generating text in streaming chunks.

property SEQUENCE_DIM_MAPPING
count_tokens(messages: List[Dict[str, str]])
count_tokens(text: str)
generate(input_ids: Array, attention_mask: Optional[Array] = None, *, graphstate: Optional[State[Key, VariableState[Any]]] = None, graphother: Optional[State[Key, VariableState[Any]]] = None, **model_kwargs) → Generator[Union[SampleState, Any], SampleState, SampleState]

Generates text in streaming chunks with comprehensive input adjustment.

Parameters
  • input_ids – Input token IDs as a JAX array

  • attention_mask – Optional attention mask for the input

  • graphstate (nn.GraphState, optional) – Model state to use for this generation call instead of the model's current state.

  • graphother (nn.GraphState, optional) – Other (non-parameter) model state to use for this generation call instead of the model's current one.

  • **model_kwargs – Additional model-specific keyword arguments

Returns

Generator yielding SampleState objects containing generation results and metrics
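generate is annotated as Generator[Union[SampleState, Any], SampleState, SampleState], so besides the states it yields, a final state arrives as the generator's return value (StopIteration.value). The toy generator below mimics that protocol with a hypothetical ToySampleState and a fake next_token function, emitting one state per streaming chunk; it does not call EasyDeL. With the real API a plain for loop that keeps the last yielded state is usually enough.

```python
from dataclasses import dataclass
from typing import Callable, Generator, List, Tuple


@dataclass(frozen=True)
class ToySampleState:
    # Hypothetical stand-in for EasyDeL's SampleState.
    tokens: Tuple[int, ...]
    generated: int


def toy_stream(
    next_token: Callable[[List[int]], int], max_new_tokens: int, streaming_chunks: int
) -> Generator[ToySampleState, None, ToySampleState]:
    tokens: List[int] = []
    while len(tokens) < max_new_tokens:
        # Decode up to streaming_chunks tokens before handing control back.
        for _ in range(min(streaming_chunks, max_new_tokens - len(tokens))):
            tokens.append(next_token(tokens))
        yield ToySampleState(tuple(tokens), len(tokens))
    return ToySampleState(tuple(tokens), len(tokens))


def consume(stream: Generator[ToySampleState, None, ToySampleState]):
    """Collect per-chunk progress and the generator's return value."""
    seen: List[int] = []
    while True:
        try:
            state = next(stream)
        except StopIteration as stop:
            return seen, stop.value
        seen.append(state.generated)


progress, final = consume(toy_stream(lambda toks: len(toks), max_new_tokens=10, streaming_chunks=4))
```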

property inference_name
classmethod load_inference(path: Union[PathLike, str], model: None, processor_class: None)
property metrics
property model
property model_prefill_length: int

Calculate the maximum length available for input prefill by subtracting the maximum new tokens from the model’s maximum sequence length.

Returns

The maximum length available for input prefill

Return type

int

Raises

ValueError – If no maximum sequence length configuration is found
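For example, a model whose maximum sequence length is 4096 tokens, used with max_new_tokens=512 (the vInference default), leaves 4096 - 512 = 3584 tokens for the prompt. The sketch below performs that subtraction on a config represented as a plain dict; which config attributes EasyDeL actually consults, and in what order, is an assumption (the two names used here both appear elsewhere in this reference).

```python
from typing import Mapping, Optional

# Hypothetical lookup order: the attribute names EasyDeL consults may differ.
MAX_LENGTH_ATTRS = ("granted_mask_max_position_embedding", "max_position_embeddings")


def model_prefill_length(config: Mapping[str, Optional[int]], max_new_tokens: int) -> int:
    """Tokens left for the prompt once max_new_tokens is reserved for generation."""
    for name in MAX_LENGTH_ATTRS:
        if config.get(name) is not None:
            return config[name] - max_new_tokens
    raise ValueError("no maximum sequence length found in the model config")


print(model_prefill_length({"max_position_embeddings": 4096}, 512))  # 3584
```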

precompile(config: vInferencePreCompileConfig)

Precompiles the generation functions for the batch size and prefill length described by config.

This function checks whether the generation functions have already been compiled for the given configuration. If not, it compiles them asynchronously and stores them in a cache.

Returns

True if precompilation was successful, False otherwise.

Return type

bool
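A synchronous, pure-Python sketch of the compile-once cache described above follows. The real method compiles asynchronously and keys its cache on a vInferencePreCompileConfig; here ToyCompileCache, the (batch_size, prefill_length) key and compile_fn are hypothetical stand-ins, and compilation failures map to the documented False return.

```python
from typing import Callable, Dict, Tuple


class ToyCompileCache:
    """Hypothetical compile-once cache keyed by input shape."""

    def __init__(self, compile_fn: Callable[[int, int], Callable]):
        self._compile_fn = compile_fn
        self._cache: Dict[Tuple[int, int], Callable] = {}

    def precompile(self, batch_size: int, prefill_length: int) -> bool:
        key = (batch_size, prefill_length)
        if key in self._cache:
            return True  # already compiled for this shape, nothing to do
        try:
            self._cache[key] = self._compile_fn(batch_size, prefill_length)
        except Exception:
            return False  # report failure instead of raising
        return True
```

Keying on static shapes reflects why precompilation matters under JAX: each new input shape triggers a fresh trace and compile, so compiling the expected shapes ahead of time moves that cost out of the request path.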

save_inference(path: Union[PathLike, str])
property tokenizer
class easydel.__init__.vInferenceApiServer(inference_map: Union[Dict[str, Any], Any] = None, inference_init_call: Optional[Callable[[], Any]] = None, max_workers: int = 10)

Bases: object

available_inference()
async chat_completions(request: ChatCompletionRequest)
count_tokens(request: CountTokenRequest)
fire(host='0.0.0.0', port=11556, metrics_port: Optional[int] = None, log_level='debug')
liveness()
patch_endpoints()

Register all endpoints with the FastAPI app.

readiness()
class easydel.__init__.vInferenceConfig(max_new_tokens: int = 64, min_length: Optional[int] = None, streaming_chunks: int = 16, temperature: float = 0.0, top_p: float = 0.95, top_k: int = 50, do_sample: bool = True, no_repeat_ngram_size: Optional[int] = None, num_return_sequences: Union[int, Dict[int, int], NoneType] = 1, suppress_tokens: Optional[list] = None, forced_bos_token_id: Optional[int] = None, forced_eos_token_id: Optional[int] = None, pad_token_id: Optional[int] = None, bos_token_id: Optional[int] = None, eos_token_id: Union[int, List[int], NoneType] = None, partition_rules: Optional[Tuple[Tuple[str, Any]]] = None, partition_axis: Optional[eformer.escale.partition.constraints.PartitionAxis] = None, _loop_rows: Optional[int] = None)

Bases: object

bos_token_id: Optional[int] = None
do_sample: bool = True
eos_token_id: Optional[Union[int, List[int]]] = None
forced_bos_token_id: Optional[int] = None
forced_eos_token_id: Optional[int] = None
get_logits_processor()
get_logits_warper()
get_partition_rules(runtime_config: Optional[vInferencePreCompileConfig] = None)
max_new_tokens: int = 64
min_length: Optional[int] = None
no_repeat_ngram_size: Optional[int] = None
num_return_sequences: Optional[Union[int, Dict[int, int]]] = 1
pad_token_id: Optional[int] = None
partition_axis: Optional[PartitionAxis] = None
partition_rules: Optional[Tuple[Tuple[str, Any]]] = None
replace(**kwargs)
streaming_chunks: int = 16
suppress_tokens: Optional[list] = None
temperature: float = 0.0
top_k: int = 50
top_p: float = 0.95
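The temperature, top_k and top_p fields configure the usual sampling warpers. The NumPy sketch below applies them to a single logits vector in the temperature, top-k, top-p order that Hugging Face Transformers uses; it is not get_logits_warper, and whether EasyDeL uses the same order, and how it treats temperature=0.0, are assumptions. Defaults mirror the vInferenceConfig defaults listed above.

```python
import numpy as np


def warp_logits(logits, temperature: float = 0.0, top_k: int = 50, top_p: float = 0.95) -> np.ndarray:
    """Mask a 1-D logits vector the way temperature/top-k/top-p warpers do."""
    logits = np.asarray(logits, dtype=np.float64)
    if temperature > 0:  # this sketch leaves logits unscaled when temperature is 0
        logits = logits / temperature
    if top_k > 0:
        kth_largest = np.sort(logits)[-min(top_k, logits.size)]
        logits = np.where(logits < kth_largest, -np.inf, logits)
    if top_p < 1.0:
        order = np.argsort(logits)[::-1]  # token indices, most likely first
        sorted_logits = logits[order]
        probs = np.exp(sorted_logits - sorted_logits[0])
        probs /= probs.sum()
        # Keep the smallest prefix whose cumulative probability reaches top_p.
        keep = np.searchsorted(np.cumsum(probs), top_p) + 1
        logits = logits.copy()
        logits[order[keep:]] = -np.inf
    return logits
```

Sampling then draws from the softmax of the returned logits; tokens set to -inf get zero probability.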
class easydel.__init__.vInferencePreCompileConfig(batch_size: Union[int, List[int]] = 1, prefill_length: Union[int, List[int], NoneType] = None, vision_included: Union[bool, List[bool]] = False, vision_batch_size: Union[int, List[int], NoneType] = None, vision_channels: Union[int, List[int], NoneType] = None, vision_height: Union[int, List[int], NoneType] = None, vision_width: Union[int, List[int], NoneType] = None, required_props: Union[Mapping[str, Dict[str, Any]], List[Mapping[str, Dict[str, Any]]], NoneType] = None)

Bases: object

batch_size: Union[int, List[int]] = 1
get_default_hash()
get_standalones()

Creates standalone configurations when any field contains a list. Returns a list of standalone vInferencePreCompileConfig instances.

For example, if batch_size=[1, 2, 3, 4], it will create 4 standalone configs with batch_size values 1, 2, 3, and 4 respectively.
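The text does not say how several list-valued fields combine. The sketch below assumes they are paired element-wise (all lists must have the same length) while scalar fields are copied into every standalone config; ToyPreCompileConfig is hypothetical and keeps only two of the real fields.

```python
from dataclasses import dataclass, fields, replace
from typing import List, Optional, Union


@dataclass
class ToyPreCompileConfig:
    batch_size: Union[int, List[int]] = 1
    prefill_length: Optional[Union[int, List[int]]] = None

    def get_standalones(self) -> List["ToyPreCompileConfig"]:
        values = {f.name: getattr(self, f.name) for f in fields(self)}
        lengths = {len(v) for v in values.values() if isinstance(v, list)}
        if not lengths:
            return [self]  # nothing to expand
        if len(lengths) > 1:
            raise ValueError("list-valued fields must have the same length")
        (n,) = lengths
        # Element i of every list goes into config i; scalars are broadcast.
        return [
            replace(self, **{k: (v[i] if isinstance(v, list) else v) for k, v in values.items()})
            for i in range(n)
        ]


configs = ToyPreCompileConfig(batch_size=[1, 2, 3, 4], prefill_length=128).get_standalones()
```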

prefill_length: Optional[Union[int, List[int]]] = None
replace(**kwargs)
required_props: Optional[Union[Mapping[str, Dict[str, Any]], List[Mapping[str, Dict[str, Any]]]]] = None
vision_batch_size: Optional[Union[int, List[int]]] = None
vision_channels: Optional[Union[int, List[int]]] = None
vision_height: Optional[Union[int, List[int]]] = None
vision_included: Union[bool, List[bool]] = False
vision_width: Optional[Union[int, List[int]]] = None
class easydel.__init__.vWhisperInference(model: Any, tokenizer: Any, processor: Any, inference_config: Optional[easydel.__init__.inference.whisper_inference.vWhisperInferenceConfig] = None, dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType] = jax.numpy.float32)

Bases: object

Whisper inference pipeline for performing speech-to-text transcription or translation.

Parameters
  • model (WhisperForConditionalGeneration) – The fine-tuned Whisper model to use for inference.

  • tokenizer (WhisperTokenizer) – Tokenizer for Whisper.

  • processor (WhisperProcessor) – Processor for Whisper.

  • inference_config (vWhisperInferenceConfig, optional) – Inference configuration.

  • dtype (jax.typing.DTypeLike, optional, defaults to jnp.float32) – Data type for computations.

Example usage:

>>> import easydel as ed
>>> import jax.numpy as jnp
>>> from transformers import WhisperTokenizer, WhisperProcessor
>>> REPO_ID = "openai/whisper-small"  # Replace with your desired model
>>> model = ed.AutoEasyDeLModelForSpeechSeq2Seq.from_pretrained(
...     REPO_ID,
...     # ... (config_kwargs as needed)
... )
>>> tokenizer = WhisperTokenizer.from_pretrained(REPO_ID)
>>> processor = WhisperProcessor.from_pretrained(REPO_ID)
>>> inference = ed.vWhisperInference(
...     model=model,
...     tokenizer=tokenizer,
...     processor=processor,
...     dtype=jnp.float16,  # Or jnp.float32
... )
>>> result = inference.generate("sample1.flac", return_timestamps=True)
>>> print(result)
>>> # Example using a URL:
>>> result_url = inference.generate(
...     "https://huggingface.co/datasets/hf-internal-testing/librispeech_asr_dummy/raw/main/common_voice_en_100038.mp3",
...     return_timestamps=True,
... )
>>> print(result_url)
>>> # Example specifying language and task:
>>> result_lang_task = inference.generate(
...     "sample1.flac", language="en", task="transcribe", return_timestamps=True
... )
>>> print(result_lang_task)
chunk_iter_with_batch(audio_array: Array, chunk_length: int, stride_left: int, stride_right: int, batch_size: int)
generate(audio_input: Union[str, bytes, ndarray, Dict[str, Union[ndarray, int]]], chunk_length_s: float = 30.0, stride_length_s: Optional[Union[float, list[float]]] = None, batch_size: Optional[int] = None, language: Optional[str] = None, task: Optional[str] = None, return_timestamps: Optional[bool] = None)

Transcribe or translate audio input.

Parameters
  • audio_input (tp.Union[str, bytes, np.ndarray, tp.Dict[str, tp.Union[np.ndarray, int]]]) – Input audio. Can be a local file path, URL, bytes, numpy array, or a dictionary containing the array and sampling rate.

  • chunk_length_s (float, optional, defaults to 30.0) – Length of audio chunks in seconds.

  • stride_length_s (float or list[float], optional) – Stride length for chunking audio, in seconds. Defaults to chunk_length_s / 6.

  • batch_size (int, optional) – Batch size for processing. Defaults to the batch_size in inference_config.

  • language (str, optional) – Language of the input audio. Defaults to the language in inference_config.

  • task (str, optional) – Task to perform (e.g., “transcribe”, “translate”). Defaults to the task in inference_config.

  • return_timestamps (bool, optional) – Whether to return timestamps with the transcription. Defaults to the return_timestamps in inference_config.

Returns

A dictionary containing the transcribed text (“text”) and optionally other information like timestamps or detected language.

Return type

dict
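The chunk_length_s and stride_length_s arguments resemble the overlapping-window scheme of the Hugging Face speech-recognition pipeline: each chunk overlaps its neighbours by a left and right stride, and that overlap information is used to reconcile the transcripts when they are merged. The sketch below converts the second-based arguments to samples and computes the windows and batches; window_sizes, chunk_spans and batched are hypothetical helpers, and it is an assumption that chunk_iter_with_batch follows the same rules.

```python
from typing import Iterable, Iterator, List, Optional, Sequence, Tuple, TypeVar, Union

T = TypeVar("T")
Span = Tuple[int, int, int, int]  # (start, end, stride_left, stride_right), all in samples


def window_sizes(chunk_length_s: float, sampling_rate: int,
                 stride_length_s: Optional[Union[float, Sequence[float]]] = None) -> Tuple[int, int, int]:
    """Convert the second-based arguments of generate() into sample counts."""
    if stride_length_s is None:
        stride_length_s = chunk_length_s / 6  # documented default
    if isinstance(stride_length_s, (int, float)):
        stride_length_s = (stride_length_s, stride_length_s)
    left_s, right_s = stride_length_s
    return (round(chunk_length_s * sampling_rate),
            round(left_s * sampling_rate),
            round(right_s * sampling_rate))


def chunk_spans(num_samples: int, chunk_length: int, stride_left: int, stride_right: int) -> Iterator[Span]:
    step = chunk_length - stride_left - stride_right
    if step <= 0:
        raise ValueError("chunk_length must be larger than stride_left + stride_right")
    for start in range(0, num_samples, step):
        is_last = start + chunk_length >= num_samples
        yield (start, min(start + chunk_length, num_samples),
               0 if start == 0 else stride_left,  # nothing to trim before the first chunk
               0 if is_last else stride_right)    # nothing to trim after the last chunk
        if is_last:
            break


def batched(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    batch: List[T] = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch
```

With 30 s chunks at 16 kHz and the default stride of 5 s per side, each window holds 480000 samples and advances by 320000.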

get_decoder_input_ids(generation_config: Optional[Any] = None, task: Optional[str] = None, language: Optional[str] = None, return_timestamps: bool = False) → list[Tuple[int, int]]
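get_decoder_input_ids is documented only by its signature, which returns list[Tuple[int, int]]. A plausible reading, based on the forced-decoder-ids convention of the Hugging Face Whisper tokenizer, is a list of (position, token_id) pairs that pin the language, task and no-timestamps tokens after the start-of-transcript token. The sketch below assumes that convention and uses made-up token IDs in the hypothetical TOY_VOCAB; real IDs come from the tokenizer.

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical special-token IDs for illustration only.
TOY_VOCAB: Dict[str, int] = {
    "<|en|>": 101, "<|fr|>": 102,
    "<|transcribe|>": 201, "<|translate|>": 202,
    "<|notimestamps|>": 301,
}


def toy_decoder_input_ids(language: Optional[str] = None, task: Optional[str] = None,
                          return_timestamps: bool = False) -> List[Tuple[int, int]]:
    prompt: List[int] = []
    if language is not None:
        prompt.append(TOY_VOCAB[f"<|{language}|>"])
    if task is not None:
        prompt.append(TOY_VOCAB[f"<|{task}|>"])
    if not return_timestamps:
        prompt.append(TOY_VOCAB["<|notimestamps|>"])
    # Position 0 holds <|startoftranscript|>, so forced tokens start at position 1.
    return [(position, token) for position, token in enumerate(prompt, start=1)]
```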
class easydel.__init__.vWhisperInferenceConfig(batch_size: Optional[int] = 1, max_length: Optional[int] = None, generation_config: Optional[Any] = None)

Bases: object

Configuration class for Whisper inference.

Parameters
  • batch_size (int, optional, defaults to 1) – Batch size used for inference.

  • max_length (int, optional) – Maximum sequence length for generation.

  • generation_config (transformers.GenerationConfig, optional) – Generation configuration object.

  • logits_processor (optional) – Not used.

  • return_timestamps (bool, optional) – Whether to return timestamps with the transcribed text.

  • task (str, optional) – Task for the model (e.g., “transcribe”, “translate”).

  • language (str, optional) – Language of the input audio.

  • is_multilingual (bool, optional) – Whether the model is multilingual.

batch_size: Optional[int] = 1
generation_config: Optional[Any] = None
is_multilingual = None
language = None
logits_processor = None
max_length: Optional[int] = None
replace(**kwargs)
return_timestamps = None
task = None