easydel.modules.xerxes2.modeling_xerxes2
- class easydel.modules.xerxes2.modeling_xerxes2.Xerxes2Attention(*args: Any, **kwargs: Any)
Bases: UnifiedAttention
Xerxes2 Multi-head Latent Attention.
Inherits the MLA implementation from the UnifiedAttention base class. Uses a compressed (low-rank, LoRA-style) KV representation with separate nope (no positional embedding) and rope (rotary) dimensions.
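The compressed KV path described above can be sketched as follows. This is a minimal illustration that assumes the common MLA layout: a down-projection that yields a low-rank latent plus a shared rope key slice, followed by an up-projection that is split into per-head nope keys and values. The function name, argument names, and shapes are hypothetical and are not taken from Xerxes2Attention; the latent layer norm and the rotary embedding itself are omitted for brevity.

```python
import jax.numpy as jnp


def mla_kv_sketch(hidden, w_kv_a, w_kv_b, kv_lora_rank, num_heads, qk_nope_dim, v_dim):
    """Illustrative MLA key/value path (hypothetical helper, not EasyDeL code).

    hidden: (batch, seq, d_model)
    w_kv_a: (d_model, kv_lora_rank + qk_rope_dim)   # down-projection
    w_kv_b: (kv_lora_rank, num_heads * (qk_nope_dim + v_dim))  # up-projection
    """
    # Down-project to a compressed latent plus a rope key slice shared by all heads.
    kv_a = hidden @ w_kv_a
    compressed, k_rope = jnp.split(kv_a, [kv_lora_rank], axis=-1)

    # Up-project the latent and separate per-head nope keys from values.
    kv_b = compressed @ w_kv_b
    kv_b = kv_b.reshape(*kv_b.shape[:-1], num_heads, qk_nope_dim + v_dim)
    k_nope, value = jnp.split(kv_b, [qk_nope_dim], axis=-1)
    return k_nope, k_rope, value


# Tiny shape check with made-up dimensions.
hidden = jnp.ones((2, 3, 16))
w_kv_a = jnp.ones((16, 4 + 2))        # kv_lora_rank=4, qk_rope_dim=2
w_kv_b = jnp.ones((4, 2 * (3 + 5)))   # num_heads=2, qk_nope_dim=3, v_dim=5
k_nope, k_rope, value = mla_kv_sketch(hidden, w_kv_a, w_kv_b, 4, 2, 3, 5)
```

Only the small latent and the rope slice need to be kept per token in such a scheme, which is the motivation for calling the representation compressed.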
- define_network(config: Xerxes2Config, dtype: dtype, param_dtype: dtype, precision: Precision, rngs: Rngs)
Defines the MLA-specific network structure for Xerxes2.
- projection_mapping: ClassVar[dict[str, str]] = {'mla_kv_a_layernorm': 'kv_a_layernorm', 'mla_kv_a_proj_with_mqa': 'kv_a_proj_with_mqa', 'mla_kv_b_proj': 'kv_b_proj', 'mla_q_a_layernorm': 'q_a_layernorm', 'mla_q_a_proj': 'q_a_proj', 'mla_q_b_proj': 'q_b_proj', 'mla_q_proj': 'q_proj', 'output_projection': 'o_proj'}
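As a rough illustration of how a mapping like projection_mapping can be read, the sketch below looks up the Xerxes2-specific attribute name for a generic projection name. The dictionary is copied from the value shown above; the helper resolve_projection is hypothetical, and how UnifiedAttention actually consumes the mapping is an assumption here.

```python
# Copied from the documented class attribute.
PROJECTION_MAPPING = {
    "mla_kv_a_layernorm": "kv_a_layernorm",
    "mla_kv_a_proj_with_mqa": "kv_a_proj_with_mqa",
    "mla_kv_b_proj": "kv_b_proj",
    "mla_q_a_layernorm": "q_a_layernorm",
    "mla_q_a_proj": "q_a_proj",
    "mla_q_b_proj": "q_b_proj",
    "mla_q_proj": "q_proj",
    "output_projection": "o_proj",
}


def resolve_projection(generic_name, mapping=PROJECTION_MAPPING):
    """Return the model-specific attribute name (hypothetical helper).

    Names that are not in the mapping are returned unchanged.
    """
    return mapping.get(generic_name, generic_name)


o_proj_name = resolve_projection("output_projection")
kv_b_name = resolve_projection("mla_kv_b_proj")
```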
- class easydel.modules.xerxes2.modeling_xerxes2.Xerxes2DecoderLayer(*args: Any, **kwargs: Any)
Bases: Module
Transformer decoder layer with Xerxes2 attention and optional MoE MLP.
- class easydel.modules.xerxes2.modeling_xerxes2.Xerxes2ForCausalLM(*args: Any, **kwargs: Any)
Bases: BaseCausalLMModule[Xerxes2Model, Xerxes2Config]
Xerxes2 model with a language modeling head for causal language modeling tasks.
This model extends the base Xerxes2Model by adding a linear language modeling head on top of the transformer stack. It incorporates a Mixture of Experts (MoE) architecture and is designed for generative tasks such as text generation.
- create_cache_metadata(batch_size: int, max_length: int, pad_token_id: int | None = None)
Creates the metadata required for initializing a standard (non-paged) KV Cache.
This method gathers parameters such as the layer count and head dimensions, determines the appropriate padding token ID, and returns a TransformerCacheMetaData object suitable for a standard sequential KV cache.
- Parameters
batch_size (int) – The batch size for which the cache is being configured.
max_length (int) – The maximum sequence length the cache needs to support.
pad_token_id (int | None) – The ID of the padding token. If None, it is looked up in self.generation_config or self.config, falling back to 0.
- Returns
An initialized metadata object for a standard KV cache.
- Return type
TransformerCacheMetaData
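The padding-token fallback described for pad_token_id can be sketched like this. It is a minimal illustration, not the EasyDeL implementation: resolve_pad_token_id is a hypothetical helper, the order (generation config first, then model config) is assumed from the order in which the parameter description lists them, and SimpleNamespace objects stand in for the real config classes.

```python
from types import SimpleNamespace


def resolve_pad_token_id(pad_token_id, generation_config, config):
    """Pick a padding token ID, falling back to 0 (hypothetical helper)."""
    # An explicit argument always wins.
    if pad_token_id is not None:
        return pad_token_id
    # Otherwise try the generation config, then the model config.
    for source in (generation_config, config):
        candidate = getattr(source, "pad_token_id", None)
        if candidate is not None:
            return candidate
    # Nothing was configured anywhere: default to 0.
    return 0


gen_cfg = SimpleNamespace(pad_token_id=None)
model_cfg = SimpleNamespace(pad_token_id=2)

explicit = resolve_pad_token_id(7, gen_cfg, model_cfg)
from_config = resolve_pad_token_id(None, gen_cfg, model_cfg)
default = resolve_pad_token_id(None, gen_cfg, SimpleNamespace())
```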
- init_cache(batch_size: int, max_length: int, starts: int | None = None, shardings: dict | None = None, pad_token_id: int | None = None)
Initializes and returns a standard (non-paged) Key-Value cache.
This method first creates the necessary metadata using create_cache_metadata and then calls TransformerCache.init_cache to allocate and initialize the cache tensors based on the model’s configuration, dtype, sharding, quantization settings, and provided batch size and maximum length.
- Parameters
batch_size (int) – The batch size for the cache.
max_length (int) – The maximum sequence length the cache needs to support.
starts (int | None) – Optional starting positions for the cache sequences. If provided, this influences the initial cache state. Defaults to None (usually treated as 0).
shardings (dict | None) – Optional dictionary specifying sharding configurations. Note: this argument appears to be unused in the current implementation.
pad_token_id (int | None) – The ID of the padding token. If None, it is inferred when the cache metadata is created (see create_cache_metadata).
- Returns
An initialized standard TransformerCache object.
- Return type
TransformerCache
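To make the idea of allocating a standard (non-paged) cache concrete, the sketch below pre-allocates one zero-filled key/value buffer per layer. It is a generic illustration under stated assumptions, not TransformerCache.init_cache: the function name, the dictionary layout, the buffer shape, and the per-sequence index field are all hypothetical, and sharding and quantization are ignored.

```python
import jax.numpy as jnp


def init_cache_sketch(num_layers, batch_size, max_length, num_kv_heads, head_dim,
                      starts=None, dtype=jnp.bfloat16):
    """Allocate a simple per-layer KV cache (hypothetical stand-in)."""
    shape = (batch_size, max_length, num_kv_heads, head_dim)
    # A missing start position is treated as 0 in this sketch.
    start = 0 if starts is None else starts
    return [
        {
            "key": jnp.zeros(shape, dtype),
            "value": jnp.zeros(shape, dtype),
            # Next write position for each sequence in the batch.
            "index": jnp.full((batch_size,), start, dtype=jnp.int32),
        }
        for _ in range(num_layers)
    ]


cache = init_cache_sketch(num_layers=2, batch_size=4, max_length=16,
                          num_kv_heads=2, head_dim=8)
```

Pre-allocating to max_length keeps array shapes static across decoding steps, which is generally what JAX compilation prefers.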
- class easydel.modules.xerxes2.modeling_xerxes2.Xerxes2MLP(*args: Any, **kwargs: Any)
Bases: Module
Feed-forward network used in dense Xerxes2 decoder layers.
- class easydel.modules.xerxes2.modeling_xerxes2.Xerxes2Model(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
Xerxes2 decoder-only stack connecting embeddings, decoder layers, and final norm.
- class easydel.modules.xerxes2.modeling_xerxes2.Xerxes2MoeMLPStack(*args: Any, **kwargs: Any)
Bases: Module
Xerxes2 MoE MLP stack built on the ParallelMoELinear layers.
- class easydel.modules.xerxes2.modeling_xerxes2.Xerxes2MoeSparseBlock(*args: Any, **kwargs: Any)
Bases: BaseMoeModule
Sparse Mixture of Experts (MoE) block for Xerxes2 MoE.
This block routes input hidden states to a selected subset of experts and combines their outputs.
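The routing described above can be sketched as top-k gating. This is a minimal illustration, not the Xerxes2MoeSparseBlock implementation: the function name and arguments are hypothetical, each expert is reduced to a single weight matrix, every expert is evaluated densely for clarity, and renormalizing the selected gate probabilities is an assumption.

```python
import jax
import jax.numpy as jnp


def sparse_moe_sketch(hidden, gate_w, expert_ws, top_k):
    """Top-k expert routing (hypothetical helper, dense for readability).

    hidden:    (tokens, d_model)
    gate_w:    (d_model, num_experts)
    expert_ws: (num_experts, d_model, d_out), one matrix per expert
    """
    num_experts = gate_w.shape[-1]
    # Gating network: one probability per expert for each token.
    probs = jax.nn.softmax(hidden @ gate_w, axis=-1)
    # Keep the top_k experts per token and renormalize their weights.
    top_p, top_idx = jax.lax.top_k(probs, top_k)
    top_p = top_p / top_p.sum(axis=-1, keepdims=True)
    # Scatter the kept weights back into a (tokens, num_experts) matrix.
    one_hot = jax.nn.one_hot(top_idx, num_experts)
    weights = (one_hot * top_p[..., None]).sum(axis=1)
    # Run every expert, then combine outputs using the sparse weights.
    all_out = jnp.einsum("td,edf->tef", hidden, expert_ws)
    return jnp.einsum("te,tef->tf", weights, all_out)


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
hidden = jax.random.normal(k1, (5, 8))
gate_w = jax.random.normal(k2, (8, 4))
expert_ws = jax.random.normal(k3, (4, 8, 8))
out = sparse_moe_sketch(hidden, gate_w, expert_ws, top_k=1)
```

A production block would dispatch tokens only to their selected experts rather than computing all of them; the dense form is used here only to keep the sketch short.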
- config
Configuration object for the model.
- Type
Xerxes2MoeConfig
- gate
Linear layer for the gating network.
- Type
- experts
List of expert MLP modules.
- Type
nn.List[Xerxes2MoeMLP]
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for matrix multiplications.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs