easydel.modules.xerxes2.modeling_xerxes2

class easydel.modules.xerxes2.modeling_xerxes2.Xerxes2Attention(*args: Any, **kwargs: Any)

Bases: UnifiedAttention

Xerxes2 Multi-head Latent Attention.

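Inherits the MLA implementation from the UnifiedAttention base class. Keys and values are derived from a compressed, low-rank (LoRA-style) latent representation, and each query/key head is split into a part without positional encoding ("nope") and a part that receives rotary position embeddings ("rope").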
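To make the benefit of the compressed representation concrete, here is a minimal sketch comparing per-token cache sizes, assuming DeepSeek-style MLA dimensions. The field names (kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim) and the example sizes are hypothetical and may not match Xerxes2Config. The sketch also assumes the rotary key part is shared across heads, which the kv_a_proj_with_mqa projection name suggests; whether the EasyDeL cache stores the latent vector or the re-expanded keys and values is not stated here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MLADims:
    """Hypothetical MLA dimensions (names borrowed from DeepSeek-style configs)."""

    num_heads: int
    qk_nope_head_dim: int  # per-head query/key width without rotary embedding
    qk_rope_head_dim: int  # query/key width that receives rotary embedding
    v_head_dim: int  # per-head value width
    kv_lora_rank: int  # width of the compressed (latent) KV vector


def full_kv_elements_per_token(d: MLADims) -> int:
    """Elements per token per layer if expanded keys and values are cached."""
    key = d.num_heads * (d.qk_nope_head_dim + d.qk_rope_head_dim)
    value = d.num_heads * d.v_head_dim
    return key + value


def latent_kv_elements_per_token(d: MLADims) -> int:
    """Elements per token per layer if only the latent KV vector and the
    single rotary key part (assumed shared across heads) are cached."""
    return d.kv_lora_rank + d.qk_rope_head_dim


dims = MLADims(num_heads=16, qk_nope_head_dim=128, qk_rope_head_dim=64,
               v_head_dim=128, kv_lora_rank=512)
print(full_kv_elements_per_token(dims))  # 5120
print(latent_kv_elements_per_token(dims))  # 576
```

With these illustrative sizes the latent form is roughly 9x smaller per token, which is the motivation for compressing keys and values in the first place.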

define_network(config: Xerxes2Config, dtype: dtype, param_dtype: dtype, precision: Precision, rngs: Rngs)

Defines the MLA-specific network structure for Xerxes2.

projection_mapping: ClassVar[dict[str, str]] = {'mla_kv_a_layernorm': 'kv_a_layernorm', 'mla_kv_a_proj_with_mqa': 'kv_a_proj_with_mqa', 'mla_kv_b_proj': 'kv_b_proj', 'mla_q_a_layernorm': 'q_a_layernorm', 'mla_q_a_proj': 'q_a_proj', 'mla_q_b_proj': 'q_b_proj', 'mla_q_proj': 'q_proj', 'output_projection': 'o_proj'}
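
A minimal sketch of how such a mapping can be used: a generic attention base class refers to projections by role name and resolves them to model-specific attribute names. The helper resolve_projection and the stand-in attn object are hypothetical; only the mapping itself is copied from the signature above. Because the mapping lists both mla_q_proj and mla_q_a_proj/mla_q_b_proj, it is assumed here that only one query path exists in a given configuration, so the sketch returns None for a role whose attribute is absent.

```python
from types import SimpleNamespace
from typing import Any

# Copied from Xerxes2Attention.projection_mapping (role name -> attribute name).
PROJECTION_MAPPING: dict[str, str] = {
    "mla_kv_a_layernorm": "kv_a_layernorm",
    "mla_kv_a_proj_with_mqa": "kv_a_proj_with_mqa",
    "mla_kv_b_proj": "kv_b_proj",
    "mla_q_a_layernorm": "q_a_layernorm",
    "mla_q_a_proj": "q_a_proj",
    "mla_q_b_proj": "q_b_proj",
    "mla_q_proj": "q_proj",
    "output_projection": "o_proj",
}


def resolve_projection(module: Any, role: str) -> Any:
    """Hypothetical helper: return the submodule playing `role`, or None if
    the module does not define the mapped attribute."""
    if role not in PROJECTION_MAPPING:
        raise KeyError(f"unknown projection role: {role!r}")
    return getattr(module, PROJECTION_MAPPING[role], None)


# Stand-in for an attention module built with a low-rank query path.
attn = SimpleNamespace(q_a_proj="q_a", q_b_proj="q_b", kv_b_proj="kv_b", o_proj="o")
print(resolve_projection(attn, "output_projection"))  # o
print(resolve_projection(attn, "mla_q_proj"))  # None
```
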
class easydel.modules.xerxes2.modeling_xerxes2.Xerxes2DecoderLayer(*args: Any, **kwargs: Any)

Bases: Module

Transformer decoder layer with Xerxes2 attention and an optional MoE MLP.
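The layer's control flow can be sketched as a residual block whose feed-forward part is either dense or MoE. This is a minimal sketch under assumptions the documentation above does not confirm: RMS normalization without a learned scale, pre-norm placement, and a DeepSeek-style rule for choosing MoE layers. The names is_moe_layer, first_k_dense and moe_every are hypothetical.

```python
import math
from typing import Callable

Vector = list[float]
SubLayer = Callable[[Vector], Vector]


def rms_norm(x: Vector, eps: float = 1e-6) -> Vector:
    """RMS normalization without a learned scale (assumed norm type)."""
    scale = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * scale for v in x]


def is_moe_layer(layer_idx: int, first_k_dense: int, moe_every: int) -> bool:
    """Hypothetical rule: the first `first_k_dense` layers stay dense, after
    which every layer whose index is a multiple of `moe_every` uses MoE."""
    return layer_idx >= first_k_dense and layer_idx % moe_every == 0


def decoder_layer(x: Vector, attention: SubLayer, dense_mlp: SubLayer,
                  moe_mlp: SubLayer, use_moe: bool) -> Vector:
    # Attention sub-block with a residual connection.
    h = [a + b for a, b in zip(x, attention(rms_norm(x)))]
    # Feed-forward sub-block (dense or MoE), also residual.
    mlp = moe_mlp if use_moe else dense_mlp
    return [a + b for a, b in zip(h, mlp(rms_norm(h)))]


def zeros(v: Vector) -> Vector:
    return [0.0] * len(v)


def ones(v: Vector) -> Vector:
    return [1.0] * len(v)


def twos(v: Vector) -> Vector:
    return [2.0] * len(v)


print(decoder_layer([1.0, 2.0], zeros, ones, twos, use_moe=False))  # [2.0, 3.0]
print(decoder_layer([1.0, 2.0], zeros, ones, twos, use_moe=True))  # [3.0, 4.0]
```
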

class easydel.modules.xerxes2.modeling_xerxes2.Xerxes2ForCausalLM(*args: Any, **kwargs: Any)

Bases: BaseCausalLMModule[Xerxes2Model, Xerxes2Config]

Xerxes2 model with a language modeling head for causal language modeling tasks.

This model extends the base Xerxes2Model by adding a linear language modeling head on top of the transformer stack. Its decoder layers can use a Mixture of Experts (MoE) MLP, and the model is intended for generative tasks such as text generation.

create_cache_metadata(batch_size: int, max_length: int, pad_token_id: int | None = None)

Creates the metadata required for initializing a standard (non-paged) KV Cache.

This method gathers parameters such as the layer count and head dimensions, determines the appropriate padding token ID, and uses them to instantiate and return a TransformerCacheMetaData object suitable for a standard sequential KV cache.

Parameters
  • batch_size (int) – The batch size for which the cache is being configured.

  • max_length (int) – The maximum sequence length the cache needs to support.

  • pad_token_id (int | None) – The ID of the padding token. If None, the method tries to read it from self.generation_config or self.config and falls back to 0.

Returns

An initialized metadata object for a standard KV cache.

Return type

TransformerCacheMetaData
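
The pad_token_id fallback described above can be sketched as follows. This is a minimal sketch, not the EasyDeL code: resolve_pad_token_id is a hypothetical helper, the lookup order (generation config before model config) is an assumption, and an explicit 0 is treated as a valid ID rather than as "missing".

```python
from types import SimpleNamespace
from typing import Any, Optional


def resolve_pad_token_id(pad_token_id: Optional[int],
                         generation_config: Any = None,
                         config: Any = None) -> int:
    """Hypothetical helper mirroring the documented fallback chain."""
    if pad_token_id is not None:
        return pad_token_id
    # Assumed order: generation config first, then the model config.
    for source in (generation_config, config):
        value = getattr(source, "pad_token_id", None)
        if value is not None:
            return value
    return 0  # documented default


gen_cfg = SimpleNamespace(pad_token_id=None)
model_cfg = SimpleNamespace(pad_token_id=2)
print(resolve_pad_token_id(None, gen_cfg, model_cfg))  # 2
print(resolve_pad_token_id(7, gen_cfg, model_cfg))  # 7
print(resolve_pad_token_id(None))  # 0
```
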

init_cache(batch_size: int, max_length: int, starts: int | None = None, shardings: dict | None = None, pad_token_id: int | None = None)

Initializes and returns a standard (non-paged) Key-Value cache.

This method first creates the necessary metadata using create_cache_metadata and then calls TransformerCache.init_cache to allocate and initialize the cache tensors based on the model’s configuration, dtype, sharding, quantization settings, and provided batch size and maximum length.

Parameters
  • batch_size (int) – The batch size for the cache.

  • max_length (int) – The maximum sequence length the cache needs to support.

  • starts (int | None) – Optional starting positions for the cached sequences. If provided, they influence the cache's initial state. Defaults to None, which usually means starting at position 0.

  • shardings (dict | None) – Optional dictionary specifying sharding configurations. Note: this argument appears to be unused by the current implementation.

  • pad_token_id (int | None) – The ID of the padding token. If None, it is inferred in the same way as in create_cache_metadata.

Returns

An initialized standard TransformerCache object.

Return type

TransformerCache
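
Before calling init_cache it can help to estimate how much memory a given batch_size and max_length will need. The sketch below assumes a standard layout in which every layer stores unquantized key and value arrays of shape (batch_size, max_length, num_kv_heads, head_dim); it ignores index bookkeeping, sharding and quantization, and the actual TransformerCache layout for Xerxes2's MLA attention may differ. kv_cache_bytes and the example sizes are hypothetical.

```python
def kv_cache_bytes(num_layers: int, batch_size: int, max_length: int,
                   num_kv_heads: int, key_dim: int, value_dim: int,
                   bytes_per_element: int = 2) -> int:
    """Rough size of a dense (non-paged) KV cache in bytes.

    Assumes one key array and one value array per layer, each holding
    batch_size * max_length * num_kv_heads vectors of width key_dim/value_dim.
    bytes_per_element=2 corresponds to bfloat16/float16 storage.
    """
    per_position = num_kv_heads * (key_dim + value_dim)
    return num_layers * batch_size * max_length * per_position * bytes_per_element


size = kv_cache_bytes(num_layers=32, batch_size=1, max_length=4096,
                      num_kv_heads=8, key_dim=128, value_dim=128)
print(size, size / 2**20)  # 536870912 512.0
```

Under these assumptions the cache grows linearly in both batch_size and max_length, so doubling either doubles the allocation.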

class easydel.modules.xerxes2.modeling_xerxes2.Xerxes2MLP(*args: Any, **kwargs: Any)

Bases: Module

Feed-forward network used in dense Xerxes2 decoder layers.
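A common form for such a dense feed-forward network is a gated (SwiGLU-style) MLP: a SiLU-activated gate projection multiplied elementwise by an up projection, followed by a down projection. The sketch below assumes that form; the documentation above does not state Xerxes2MLP's activation or projection names, so gated_mlp and the w_gate/w_up/w_down names are hypothetical. It uses plain Python lists to stay dependency-free.

```python
import math

Matrix = list[list[float]]


def matvec(x: list[float], w: Matrix) -> list[float]:
    """Row vector x (length d_in) times w (shape d_in x d_out)."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def silu(z: float) -> float:
    return z / (1.0 + math.exp(-z))


def gated_mlp(x: list[float], w_gate: Matrix, w_up: Matrix, w_down: Matrix) -> list[float]:
    gate = [silu(g) for g in matvec(x, w_gate)]  # activated gate branch
    up = matvec(x, w_up)  # linear branch
    hidden = [g * u for g, u in zip(gate, up)]  # elementwise gating
    return matvec(hidden, w_down)  # project back to the hidden size


eye = [[1.0, 0.0], [0.0, 1.0]]
print(gated_mlp([1.0, 0.0], eye, eye, eye))  # approximately [0.731, 0.0]
```
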

class easydel.modules.xerxes2.modeling_xerxes2.Xerxes2Model(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Xerxes2 decoder-only stack connecting embeddings, decoder layers, and final norm.

property frequencies: Array

Returns the rotary position-embedding frequency values derived from the config.
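Assuming standard rotary position embeddings (RoPE), the frequencies are the per-pair inverse frequencies theta^(-2i/d), and the rotation angle at position p is p times each frequency. This minimal sketch uses a hypothetical rope_inv_frequencies helper and an assumed base of theta = 10000; the real property may apply a scaling scheme from the config and may return a precomputed cos/sin table rather than raw frequencies. In MLA only the rope part of each query/key head is rotated, so rotary_dim would correspond to that part.

```python
import math


def rope_inv_frequencies(rotary_dim: int, theta: float = 10000.0) -> list[float]:
    """Inverse frequency for each channel pair: theta ** (-2i / rotary_dim)."""
    if rotary_dim % 2 != 0:
        raise ValueError("rotary_dim must be even")
    return [theta ** (-2 * i / rotary_dim) for i in range(rotary_dim // 2)]


def rope_cos_sin(position: int, inv_freq: list[float]) -> tuple[list[float], list[float]]:
    """cos/sin of the rotation angle position * f for every frequency f."""
    angles = [position * f for f in inv_freq]
    return [math.cos(a) for a in angles], [math.sin(a) for a in angles]


inv_freq = rope_inv_frequencies(4)
print(inv_freq)  # approximately [1.0, 0.01]
cos, sin = rope_cos_sin(0, inv_freq)
print(cos, sin)  # [1.0, 1.0] [0.0, 0.0]
```
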

get_decoder()

Returns the decoder part of the model’s graph definition.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model's graph definition. Decoder-only models do not have an encoder.

get_lm_head()

Returns the language model head of the module. Base models do not have a language model head.

class easydel.modules.xerxes2.modeling_xerxes2.Xerxes2MoeMLPStack(*args: Any, **kwargs: Any)

Bases: Module

Stacked expert MLP for Xerxes2 MoE layers, built from ParallelMoELinear layers.
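Stacking expert weights along a leading expert axis lets all tokens routed to the same expert be processed together. The sketch below illustrates that idea with plain lists: tokens are grouped by their assigned expert, and each group is multiplied by that expert's slice of the stacked weights. It is a hypothetical illustration of grouped expert computation (one expert per token for brevity), not the ParallelMoELinear implementation; grouped_expert_matmul is an invented name.

```python
from collections import defaultdict

Matrix = list[list[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """(n x k) @ (k x m) with plain lists."""
    return [[sum(row[t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for row in a]


def grouped_expert_matmul(tokens: Matrix, expert_ids: list[int],
                          stacked_weights: list[Matrix]) -> Matrix:
    """Apply stacked_weights[e] to every token routed to expert e."""
    groups: dict[int, list[int]] = defaultdict(list)
    for position, expert in enumerate(expert_ids):
        groups[expert].append(position)
    out: Matrix = [[] for _ in tokens]
    for expert, positions in groups.items():
        # One matmul per expert over all of its tokens.
        result = matmul([tokens[p] for p in positions], stacked_weights[expert])
        for p, row in zip(positions, result):
            out[p] = row  # scatter back to the original token order
    return out


identity = [[1.0, 0.0], [0.0, 1.0]]
double = [[2.0, 0.0], [0.0, 2.0]]
tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(grouped_expert_matmul(tokens, [1, 0, 1], [identity, double]))
# [[2.0, 4.0], [3.0, 4.0], [10.0, 12.0]]
```
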

class easydel.modules.xerxes2.modeling_xerxes2.Xerxes2MoeSparseBlock(*args: Any, **kwargs: Any)

Bases: BaseMoeModule

Sparse Mixture of Experts (MoE) block for Xerxes2 MoE.

This block routes input hidden states to a selected subset of experts and combines their outputs.
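The route-and-combine step can be sketched for a single token as follows. This is a minimal sketch assuming softmax gating over the gate logits, top-k selection, and renormalization of the selected weights; Xerxes2's actual scoring function, value of k, normalization, and any shared or grouped experts are not specified above. route_top_k and moe_combine are hypothetical names.

```python
import math
from typing import Callable

Vector = list[float]


def softmax(xs: Vector) -> Vector:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def route_top_k(gate_logits: Vector, k: int) -> tuple[list[int], Vector]:
    """Pick the k most probable experts and renormalize their weights to sum to 1."""
    probs = softmax(gate_logits)
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    total = sum(probs[i] for i in top)
    return top, [probs[i] / total for i in top]


def moe_combine(x: Vector, gate_logits: Vector,
                experts: list[Callable[[Vector], Vector]], k: int) -> Vector:
    """Weighted sum of the selected experts' outputs for one token."""
    out = [0.0] * len(x)
    for expert, weight in zip(*route_top_k(gate_logits, k)):
        out = [o + weight * y for o, y in zip(out, experts[expert](x))]
    return out


experts = [lambda v: [1.0] * len(v), lambda v: [2.0] * len(v), lambda v: [100.0] * len(v)]
logits = [0.0, math.log(3.0), -10.0]
print(route_top_k(logits, k=2))  # experts [1, 0] with weights of about [0.75, 0.25]
print(moe_combine([0.5, 0.5], logits, experts, k=2))  # about [1.75, 1.75]
```

Only the selected experts are evaluated, which is what makes the block sparse: the third expert's output never enters the sum.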

config

Configuration object for the model.

Type

Xerxes2MoeConfig

gate

Linear layer for the gating network.

Type

ParallelLinear

experts

List of expert MLP modules.

Type

nn.List[Xerxes2MoeMLP]

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs