easydel.modules.openelm.__init__

class easydel.modules.openelm.__init__.OpenELMConfig(vocab_size: int = 32000, max_context_length: int = 2048, num_transformer_layers: int = 12, model_dim: int = 2048, head_dim: int = 128, qkv_multipliers: Union[Number, List[Number]] = 1.0, num_query_heads: Optional[int] = None, num_gqa_groups: int = 1, ffn_multipliers: Union[Number, List[Number]] = 4.0, ffn_with_glu: bool = True, ffn_dim_divisor: int = 256, activation_fn_name: str = 'swish', normalization_layer_name: str = 'rms_norm', normalize_qk_projections: bool = False, share_input_output_layers: bool = False, rope_freq_constant: int = 10000, rope_max_length: int = 4096, initializer_range: float = 0.02, use_cache: bool = True, bos_token_id: int = 1, eos_token_id: int = 2, rope_scaling: Dict[str, Union[str, float]] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, bits: Optional[int] = None, **kwargs)

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 32000) – Vocabulary size of the OpenELM model. Defines the number of distinct tokens that can be represented by the input_ids passed to the forward method.

  • max_context_length (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 2048 or 4096).

  • num_transformer_layers (int, optional, defaults to 12) – Number of transformer (decoder) layers in the model.

  • model_dim (int, optional, defaults to 2048) – Dimensionality of the token embeddings and of the hidden states passed between decoder layers.

  • head_dim (int, optional, defaults to 128) – Dimensionality of the attention heads.

  • qkv_multipliers (float or list of float, optional, defaults to 1.0) – The multiplier for the query, key, and value projections.

  • num_query_heads (int, optional) – Number of query heads. If not provided, it will be calculated based on model_dim and head_dim.

  • num_gqa_groups (int, optional, defaults to 1) – Number of GQA (Grouped Query Attention) groups.

  • ffn_multipliers (float or list of float, optional, defaults to 4.0) – The multiplier for the feed-forward network.

  • ffn_with_glu (bool, optional, defaults to True) – Whether to use a gated linear unit (GLU) in the feed-forward network.

  • ffn_dim_divisor (int, optional, defaults to 256) – The divisor for the feed-forward network dimension.

  • activation_fn_name (str, optional, defaults to “swish”) – The activation function to use.

  • normalization_layer_name (str, optional, defaults to “rms_norm”) – The normalization layer to use.

  • normalize_qk_projections (bool, optional, defaults to False) – Whether to normalize the query and key projections.

  • share_input_output_layers (bool, optional, defaults to False) – Whether to share the input and output layers.

  • rope_freq_constant (int, optional, defaults to 10000) – The frequency constant for Rotary Position Embeddings (RoPE).

  • rope_max_length (int, optional, defaults to 4096) – The maximum length for RoPE.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • use_cache (bool, optional, defaults to True) – Whether the model should return the last key/value attention states so they can be reused during decoding (not used by all models). Only relevant if config.is_decoder=True.

  • bos_token_id (int, optional, defaults to 1) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing configuration.

  • use_scan_mlp (bool, optional, defaults to False) – Whether to use the scan implementation for the MLP.

  • scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size to use when scanning the MLP.

  • bits (int, optional) – The number of bits to quantize the model to.
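The multiplier and divisor parameters above interact when layer widths are derived from model_dim. The following is a minimal sketch of one plausible derivation, not EasyDeL's confirmed implementation: round_to_multiple is a hypothetical helper, and the assumption that the query width is rounded to a multiple of head_dim * num_gqa_groups while the feed-forward width is rounded to a multiple of ffn_dim_divisor is inferred from the parameter descriptions only.

```python
def round_to_multiple(value: float, divisor: int) -> int:
    """Round `value` to the nearest multiple of `divisor` (hypothetical helper).

    The result never drops below `divisor`, and it is bumped up one step if
    rounding would shrink the value by more than 10%.
    """
    rounded = max(divisor, int(value + divisor / 2) // divisor * divisor)
    if rounded < 0.9 * value:
        rounded += divisor
    return rounded


# Defaults taken from the OpenELMConfig signature.
model_dim, head_dim, num_gqa_groups = 2048, 128, 1
qkv_multiplier, ffn_multiplier, ffn_dim_divisor = 1.0, 4.0, 256

# Assumption: the query width is a multiple of head_dim * num_gqa_groups.
query_dim = round_to_multiple(model_dim * qkv_multiplier, head_dim * num_gqa_groups)
num_query_heads = query_dim // head_dim
num_kv_heads = num_query_heads // num_gqa_groups

# Assumption: the feed-forward width is a multiple of ffn_dim_divisor.
ffn_dim = round_to_multiple(model_dim * ffn_multiplier, ffn_dim_divisor)
```

Under these assumptions the defaults give 16 query heads and a feed-forward width of 8192. When qkv_multipliers or ffn_multipliers is a list, the same computation would presumably be applied per layer.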

attribute_map: dict[str, str] = {'tie_word_embedding': 'share_input_output_layers'}

get_partition_rules(*args, **kwargs)

Get the partition rules for the model. This method defines how the model’s parameters are partitioned across devices for distributed training and inference.

Parameters
  • *args – Additional positional arguments (unused).

  • **kwargs – Additional keyword arguments (unused).

Returns

A tuple of partition rules, where each rule is a tuple containing a regex pattern for parameter names and the corresponding PartitionSpec.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]
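To illustrate how rules of this shape are typically consumed, here is a small sketch that resolves a parameter name against an ordered tuple of (regex, spec) pairs. Everything in it is an assumption for illustration: the patterns, the axis names, and the spec_for helper are hypothetical, and plain tuples stand in for PartitionSpec so the sketch runs without JAX installed.

```python
import re

# Hypothetical rules; plain tuples stand in for jax.sharding.PartitionSpec.
EXAMPLE_RULES = (
    (r"token_embeddings/embedding", ("tp", ("fsdp", "sp"))),
    (r"attn/qkv_proj/kernel", (("fsdp", "sp"), "tp")),
    (r".*", (None,)),  # catch-all: replicate anything not matched above
)


def spec_for(param_name, rules):
    """Return the spec of the first rule whose regex matches `param_name`."""
    for pattern, spec in rules:
        if re.search(pattern, param_name):
            return spec
    raise ValueError(f"No partition rule matches {param_name!r}")


qkv_spec = spec_for("transformer/layers/0/attn/qkv_proj/kernel", EXAMPLE_RULES)
norm_spec = spec_for("transformer/norm/scale", EXAMPLE_RULES)
```

Because the first matching rule wins in this sketch, specific patterns go before the catch-all.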

static get_weight_decay_exclusions()

Returns a tuple of parameter names for which weight decay should be excluded.

Returns

A tuple containing ‘bias’, ‘normalization’, and ‘emb’ as exclusions.

Return type

tuple
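One way such exclusions could be used is to build a per-parameter mask for an optimizer with weight decay. The sketch below assumes the exclusions are matched as substrings of the parameter path; that matching rule, the decay_mask helper, and the example parameter names are all hypothetical.

```python
# The three exclusion names come from the description above.
EXCLUSIONS = ("bias", "normalization", "emb")


def decay_mask(param_names, exclusions=EXCLUSIONS):
    """Map each parameter name to True if weight decay should apply to it."""
    return {
        name: not any(key in name for key in exclusions)
        for name in param_names
    }


mask = decay_mask(
    [
        "layers/0/attn/qkv_proj/kernel",
        "layers/0/attn/qkv_proj/bias",
        "token_embeddings/embedding",
    ]
)
```

Here only the kernel entry maps to True; the bias and embedding entries are excluded from decay.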

property granted_freq_max_position_embedding: int

Returns the maximum position embedding size specifically for frequency-based position embeddings.

If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_context_length.

Returns

The granted maximum position embedding size for frequency encoding.

Return type

int

property granted_mask_max_position_embedding: int

Returns the maximum position embedding size specifically for mask-based position embeddings.

If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_context_length.

Returns

The granted maximum position embedding size for mask encoding.

Return type

int
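The fallback behaviour described for the two granted_* properties can be sketched with a small stand-in class. PositionLimits is hypothetical and is not part of EasyDeL; the sketch also assumes that "set" means "not None".

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class PositionLimits:
    """Hypothetical stand-in holding only the fields the properties read."""

    max_context_length: int = 2048
    freq_max_position_embeddings: Optional[int] = None
    mask_max_position_embeddings: Optional[int] = None

    @property
    def granted_freq_max_position_embedding(self) -> int:
        # Prefer the explicit override, else fall back to the context length.
        if self.freq_max_position_embeddings is not None:
            return self.freq_max_position_embeddings
        return self.max_context_length

    @property
    def granted_mask_max_position_embedding(self) -> int:
        if self.mask_max_position_embeddings is not None:
            return self.mask_max_position_embeddings
        return self.max_context_length


limits = PositionLimits(freq_max_position_embeddings=4096)
```

With this instance, the frequency limit resolves to the override (4096) while the mask limit falls back to max_context_length (2048).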

model_type: str = 'openelm'

static rng_keys()

Returns the names of the random number generator keys used by the model.

Returns

A tuple containing “params” and “dropout” as the RNG keys.

Return type

tuple

class easydel.modules.openelm.__init__.OpenELMForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

OpenELM model with a Causal Language Modeling head.

This model consists of the base OpenELM transformer (OpenELMModel) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction. Optionally, the input token embeddings can be tied to the output projection layer.

config

Configuration object for the model.

Type

OpenELMConfig

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

transformer

The core OpenELM transformer model.

Type

OpenELMModel

lm_head

The linear layer for projecting hidden states to vocabulary logits. This is None if config.share_input_output_layers is True.

Type

ParallelLinear, optional
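When lm_head is None because the input and output layers are shared, the logits presumably come from reusing the token embedding matrix as the output projection. The sketch below shows that standard weight-tying computation (logits = hidden @ embedding.T) with plain Python lists. It is an assumption about the general technique, not a reproduction of EasyDeL's code, and project_to_vocab is a hypothetical name.

```python
def project_to_vocab(hidden, embedding):
    """Compute logits = hidden @ embedding.T using nested lists.

    hidden:    (seq_len, model_dim) hidden states
    embedding: (vocab_size, model_dim) token embedding matrix
    returns:   (seq_len, vocab_size) logits
    """
    return [
        [sum(h * e for h, e in zip(hidden_row, token_row)) for token_row in embedding]
        for hidden_row in hidden
    ]


embedding = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # vocab_size=3, model_dim=2
hidden = [[2.0, 3.0]]  # seq_len=1, model_dim=2
logits = project_to_vocab(hidden, embedding)
```

Each logit is the dot product of a hidden state with one token's embedding row, so no separate output weight matrix is needed.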

class easydel.modules.openelm.__init__.OpenELMModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

The base OpenELM model transformer.

This class represents the core transformer architecture of the OpenELM model, consisting of an embedding layer, multiple OpenELMDecoderLayer layers, and a final RMS normalization layer.

config

Configuration object for the model.

Type

OpenELMConfig

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

token_embeddings

Embedding layer for input tokens.

Type

nn.Embed

layers

List of decoder layers.

Type

tp.List[OpenELMDecoderLayer]

norm

Final layer normalization.

Type

RMSNorm

gradient_checkpointing

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers

property frequencies

Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_frequencies() and caches the result.

Returns

The frequency components, potentially cached.

Return type

jnp.ndarray
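For orientation, the sketch below computes the standard RoPE inverse frequencies, 1 / base^(2i / head_dim), and caches them on first access. It only illustrates the general technique: the exact contents and layout of what get_basic_frequencies() returns are not specified here, and rope_inverse_frequencies and FrequencyHolder are hypothetical names. Plain Python lists are used instead of jnp.ndarray so the sketch has no dependencies.

```python
from functools import cached_property


def rope_inverse_frequencies(head_dim: int, base: float = 10000.0) -> list:
    """Return head_dim // 2 inverse frequencies, one per rotated pair."""
    return [1.0 / (base ** (i / head_dim)) for i in range(0, head_dim, 2)]


class FrequencyHolder:
    """Hypothetical holder that computes the frequencies once, then reuses them."""

    def __init__(self, head_dim: int = 128, base: float = 10000.0):
        self.head_dim = head_dim
        self.base = base

    @cached_property
    def frequencies(self) -> list:
        return rope_inverse_frequencies(self.head_dim, self.base)


holder = FrequencyHolder()  # head_dim and rope_freq_constant defaults from the config
first_access = holder.frequencies
second_access = holder.frequencies  # served from the cache, not recomputed
```

The rotation angle for position p and pair i is then p * frequencies[i], which is what the cosine and sine tables of a RoPE implementation are built from.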