easydel.modules.phimoe.modeling_phimoe_flax#
- class easydel.modules.phimoe.modeling_phimoe_flax.PhiMoEAttention(*args: Any, **kwargs: Any)#
Bases: AttentionModule
PhiMoE Attention module.
This module implements the multi-head attention mechanism used in the PhiMoE model, which is similar to the one in Phi-3. It supports Grouped Query Attention (GQA) and Rotary Position Embeddings (RoPE), including scaling options.
- config#
Configuration object for the model.
- Type
- layer_idx#
Index of the current layer.
- Type
int
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- attention_dropout#
Dropout probability for attention scores.
- Type
float
- hidden_size#
Dimensionality of the hidden states.
- Type
int
- num_heads#
Number of attention query heads.
- Type
int
- head_dim#
Dimensionality of each attention head.
- Type
int
- num_key_value_heads#
Number of attention key/value heads (for GQA).
- Type
int
- num_key_value_groups#
Number of query heads that share each key/value head (num_heads // num_key_value_heads).
- Type
int
- max_position_embeddings#
Maximum sequence length supported by RoPE.
- Type
int
- original_max_position_embeddings#
Original max sequence length for RoPE scaling.
- Type
int
- rope_theta#
Base value for RoPE frequency calculation.
- Type
float
- rope_scaling#
Configuration for RoPE scaling.
- Type
dict
- is_causal#
Whether the attention is causal (always True for this implementation).
- Type
bool
- q_proj#
Linear layer for query projection.
- Type
- k_proj#
Linear layer for key projection.
- Type
- v_proj#
Linear layer for value projection.
- Type
- o_proj#
Linear layer for the output projection.
- Type
- attention_performer#
Module to perform the core attention computation.
- rotary#
Rotary position embedding module.
- Type
RoPE
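The relationship between num_heads, num_key_value_heads and num_key_value_groups described above can be illustrated with a minimal JAX sketch of Grouped Query Attention. This is not the EasyDeL implementation: the sizes are made up, the projections and RoPE are omitted, and repeating the key/value heads with jnp.repeat is only one simple way to express the sharing.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes only (assumptions); real values come from the model config.
num_heads = 8
num_key_value_heads = 2
head_dim = 4
seq_len = 5

# Each key/value head is shared by this many query heads.
num_key_value_groups = num_heads // num_key_value_heads

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (seq_len, num_heads, head_dim))
k = jax.random.normal(kk, (seq_len, num_key_value_heads, head_dim))
v = jax.random.normal(kv, (seq_len, num_key_value_heads, head_dim))

# Expand K/V along the head axis so every query head has a matching K/V head.
k_rep = jnp.repeat(k, num_key_value_groups, axis=1)
v_rep = jnp.repeat(v, num_key_value_groups, axis=1)

# Scaled dot-product scores, shape (heads, query_pos, key_pos).
scores = jnp.einsum("qhd,khd->hqk", q, k_rep) / (head_dim ** 0.5)

# Causal mask: a query position may only attend to itself and earlier positions.
causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
scores = jnp.where(causal[None, :, :], scores, -jnp.inf)

weights = jax.nn.softmax(scores, axis=-1)
out = jnp.einsum("hqk,khd->qhd", weights, v_rep)  # (seq_len, num_heads, head_dim)
```

With these sizes, query heads 0 to 3 read from key/value head 0 and query heads 4 to 7 read from key/value head 1, so the key/value tensors are four times smaller than they would be with one key/value head per query head.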
- class easydel.modules.phimoe.modeling_phimoe_flax.PhiMoEBlockSparseTop2MLP(*args: Any, **kwargs: Any)#
Bases: Module
PhiMoE Block Sparse Top-2 MLP module.
This module implements the feed-forward network (MLP) for a single expert in the PhiMoE model’s Mixture of Experts layer. It uses a Gated Linear Unit (GLU) structure with SiLU activation.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- w1#
First linear layer (part of the GLU gate).
- Type
- w2#
Second linear layer (down-projection).
- Type
- w3#
Third linear layer (part of the GLU value).
- Type
- act_fn#
Activation function (SiLU).
- Type
callable
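As a rough sketch of the GLU structure described above, the three layers can be combined as w2(act_fn(w1(x)) * w3(x)). The code below uses plain weight matrices without biases and made-up sizes; both are assumptions for illustration, not details taken from the EasyDeL source.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes only (assumptions).
hidden_size, intermediate_size = 8, 16

k1, k2, k3, kx = jax.random.split(jax.random.PRNGKey(0), 4)
w1 = jax.random.normal(k1, (hidden_size, intermediate_size)) * 0.1  # gate branch
w3 = jax.random.normal(k3, (hidden_size, intermediate_size)) * 0.1  # value branch
w2 = jax.random.normal(k2, (intermediate_size, hidden_size)) * 0.1  # down-projection


def expert_mlp(x):
    """Gated MLP: SiLU-activated gate multiplied by the value, then projected down."""
    gate = jax.nn.silu(x @ w1)
    value = x @ w3
    return (gate * value) @ w2


x = jax.random.normal(kx, (3, hidden_size))  # three tokens
y = expert_mlp(x)                            # same shape as the input
```

The output has the same trailing dimension as the input, which is what lets the MoE block sum the outputs of several experts for one token.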
- class easydel.modules.phimoe.modeling_phimoe_flax.PhiMoeDecoderLayer(*args: Any, **kwargs: Any)#
Bases: Module
PhiMoE Transformer Decoder Layer.
This module represents a single decoder layer in the PhiMoE model. It combines self-attention and a Sparse Mixture of Experts (MoE) block (or a standard MLP if not an MoE layer) with residual connections and RMS normalization.
- config#
Configuration object for the model.
- Type
- layer_idx#
Index of the current layer.
- Type
int
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- self_attn#
The self-attention module.
- Type
- mlp#
The Sparse MoE block.
- Type
- post_attention_layernorm#
RMS normalization applied after the attention layer and before the MoE block.
- Type
- dropout#
Dropout layer (potentially unused; dropout is often handled within submodules).
- Type
nn.Dropout
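The data flow through a decoder layer of this kind can be sketched as two residual branches, each preceded by a normalization. In the sketch below, rms_norm, decoder_layer, attn_fn and moe_fn are hypothetical names, and the normalization applied before attention is an assumption, since only post_attention_layernorm is listed above.

```python
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    """Scale x by the reciprocal root-mean-square of its last axis."""
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x / jnp.sqrt(variance + eps) * weight


def decoder_layer(x, attn_fn, moe_fn, w_in, w_post):
    """Pre-norm residual layout: attention branch, then MoE branch."""
    h = x + attn_fn(rms_norm(x, w_in))       # self-attention with residual
    return h + moe_fn(rms_norm(h, w_post))   # MoE block with residual


# Stand-in sublayers (assumptions) so the sketch runs on its own.
hidden_size = 4
x = jnp.arange(8.0).reshape(2, hidden_size)
ones = jnp.ones(hidden_size)
out = decoder_layer(x, lambda t: 0.5 * t, lambda t: 0.1 * t, ones, ones)
```

Because each sublayer only adds to the running hidden state, a sublayer that returns zeros leaves the input unchanged.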
- class easydel.modules.phimoe.modeling_phimoe_flax.PhiMoeForCausalLM(*args: Any, **kwargs: Any)#
Bases: EasyDeLBaseModule
PhiMoE model with a Causal Language Modeling head.
This model consists of the base PhiMoE transformer (PhiMoeModel) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next-token prediction. Optionally, the input token embeddings can be tied to the output projection layer.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- model#
The core PhiMoE transformer model.
- Type
- lm_head#
The linear layer for projecting hidden states to vocabulary logits.
- Type
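The projection from hidden states to vocabulary logits, including the tied-embedding option, can be sketched as follows. The sizes are made up and the random hidden array is only a stand-in for the transformer output; none of this is taken from the EasyDeL source.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes only (assumptions).
vocab_size, hidden_size, seq_len = 11, 8, 3

ke, kh = jax.random.split(jax.random.PRNGKey(0))
embedding = jax.random.normal(ke, (vocab_size, hidden_size))  # token embedding table
hidden = jax.random.normal(kh, (seq_len, hidden_size))        # stand-in for transformer output

# With tied embeddings, the output projection reuses the embedding table transposed.
lm_head_kernel = embedding.T                                  # (hidden_size, vocab_size)

logits = hidden @ lm_head_kernel                              # (seq_len, vocab_size)

# Greedy choice of the next token from the last position's logits.
next_token = jnp.argmax(logits[-1])
```

Without tying, lm_head_kernel would be a separate parameter of the same shape.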
- class easydel.modules.phimoe.modeling_phimoe_flax.PhiMoeModel(*args: Any, **kwargs: Any)#
Bases: EasyDeLBaseModule
The base PhiMoE model transformer.
This class represents the core transformer architecture of the PhiMoE model, consisting of an embedding layer, multiple PhiMoeDecoderLayer layers (which include Sparse Mixture of Experts blocks), and a final RMS normalization layer.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- embed_tokens#
Embedding layer for input tokens.
- Type
nn.Embed
- layers#
List of decoder layers.
- Type
tp.List[PhiMoeDecoderLayer]
- embed_dropout#
Dropout layer applied after embeddings.
- Type
nn.Dropout
- gradient_checkpointing#
Gradient checkpointing configuration.
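For background on gradient checkpointing, the sketch below wraps each layer of a toy stack in jax.checkpoint, which recomputes that layer's activations during the backward pass instead of storing them. How the gradient_checkpointing setting maps onto this in EasyDeL is not stated here; the toy layer, the loss, and the sizes are all assumptions.

```python
import jax
import jax.numpy as jnp


def layer(x, w):
    """Toy residual layer standing in for a decoder layer."""
    return x + jnp.tanh(x @ w)


def loss_fn(weights, x, use_checkpoint):
    # jax.checkpoint trades extra compute for lower activation memory.
    fn = jax.checkpoint(layer) if use_checkpoint else layer
    for w in weights:
        x = fn(x, w)
    return jnp.sum(x ** 2)


k0, k1, k2, kx = jax.random.split(jax.random.PRNGKey(0), 4)
weights = [jax.random.normal(k, (8, 8)) * 0.1 for k in (k0, k1, k2)]
x = jax.random.normal(kx, (2, 8))

# Gradients with respect to the weights, with and without rematerialization.
grad_plain = jax.grad(lambda ws: loss_fn(ws, x, False))(weights)
grad_ckpt = jax.grad(lambda ws: loss_fn(ws, x, True))(weights)
```

Both calls compute the same gradients; checkpointing changes only what is kept in memory between the forward and backward passes.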
- class easydel.modules.phimoe.modeling_phimoe_flax.PhiMoeSparseMoeBlock(*args: Any, **kwargs: Any)#
Bases: Module
PhiMoE Sparse Mixture of Experts (MoE) Block.
This module implements the core MoE logic, including the router (gate) and the expert layers. It routes each token to the top-k experts based on the router logits and combines the expert outputs.
- config#
Configuration object for the model.
- Type
- layer_idx#
Index of the current layer.
- Type
int
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- gate#
Linear layer for the router gate.
- Type
- experts#
List of expert MLP modules.
- Type
tp.List[PhiMoEBlockSparseTop2MLP]
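A generic top-k routing step of the kind described above can be sketched as follows. It is an illustration under several assumptions: the experts are plain linear maps rather than gated MLPs, every expert is evaluated densely for simplicity, and the softmax over the selected logits is one common weighting scheme that may differ from the routing function this module actually uses.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes only (assumptions).
num_experts, top_k, hidden_size, num_tokens = 4, 2, 8, 3

kg, ke, kx = jax.random.split(jax.random.PRNGKey(0), 3)
gate = jax.random.normal(kg, (hidden_size, num_experts))                        # router weights
expert_w = jax.random.normal(ke, (num_experts, hidden_size, hidden_size)) * 0.1  # stand-in experts
x = jax.random.normal(kx, (num_tokens, hidden_size))

# Router logits: one score per (token, expert).
router_logits = x @ gate

# Keep the top-k experts per token and normalize their scores.
top_logits, top_idx = jax.lax.top_k(router_logits, top_k)
top_weights = jax.nn.softmax(top_logits, axis=-1)

# Scatter the k weights back into a dense (tokens, experts) matrix; unselected experts get 0.
dense_weights = jnp.sum(
    jax.nn.one_hot(top_idx, num_experts) * top_weights[..., None], axis=1
)

# Evaluate all experts (simple but wasteful), then take the weighted combination.
expert_out = jnp.einsum("th,ehd->ted", x, expert_w)      # (tokens, experts, hidden)
y = jnp.einsum("te,ted->td", dense_weights, expert_out)  # (tokens, hidden)
```

A production implementation would dispatch each token only to its selected experts; the dense evaluation here just keeps the combination step easy to follow.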