easydel.modules.phimoe.modeling_phimoe
- class easydel.modules.phimoe.modeling_phimoe.PhiMoEAttention(*args: Any, **kwargs: Any)
Bases: UnifiedAttention
PhiMoE attention powered by UnifiedAttention, with an optional sharding constraint.
- class easydel.modules.phimoe.modeling_phimoe.PhiMoEBlockSparseTop2MLP(*args: Any, **kwargs: Any)
Bases: Module
PhiMoE Block Sparse Top-2 MLP module.
This module implements the feed-forward network (MLP) for a single expert in the PhiMoE model’s Mixture of Experts layer. It uses a Gated Linear Unit (GLU) structure with SiLU activation.
- config: Configuration object for the model.
- dtype (jnp.dtype): Data type for computations.
- param_dtype (jnp.dtype): Data type for parameters.
- precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
- w1: First linear layer (part of the GLU gate).
- w2: Second linear layer (down-projection).
- w3: Third linear layer (part of the GLU value).
- act_fn (callable): Activation function (SiLU).
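The attribute descriptions above suggest the usual SiLU-gated (SwiGLU-style) expert computation, `w2(act_fn(w1(x)) * w3(x))`. The following is a minimal JAX sketch under that assumption, using plain weight matrices without biases in place of the module's linear layers; `expert_mlp` and the shapes are illustrative, not EasyDeL's API.

```python
import jax
import jax.numpy as jnp


def expert_mlp(x, w1, w2, w3):
    """One expert's gated feed-forward pass (sketch).

    x:  [tokens, hidden]
    w1: [hidden, intermediate]  gate projection (activation applied here)
    w3: [hidden, intermediate]  value projection (stays linear)
    w2: [intermediate, hidden]  down-projection back to the hidden size
    """
    gate = jax.nn.silu(x @ w1)   # act_fn on the gate branch
    value = x @ w3               # linear value branch
    return (gate * value) @ w2   # elementwise gating, then project down


key = jax.random.PRNGKey(0)
k_x, k1, k2, k3 = jax.random.split(key, 4)
hidden, intermediate, tokens = 8, 32, 4
x = jax.random.normal(k_x, (tokens, hidden))
w1 = jax.random.normal(k1, (hidden, intermediate)) * 0.02
w3 = jax.random.normal(k3, (hidden, intermediate)) * 0.02
w2 = jax.random.normal(k2, (intermediate, hidden)) * 0.02

y = expert_mlp(x, w1, w2, w3)
print(y.shape)  # (4, 8)
```

Because SiLU(0) = 0, an all-zero input produces an all-zero output, which is a quick sanity check on the gating.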
- class easydel.modules.phimoe.modeling_phimoe.PhiMoeDecoderLayer(*args: Any, **kwargs: Any)
Bases: Module
PhiMoE Transformer Decoder Layer.
This module represents a single decoder layer in the PhiMoE model. It combines self-attention and a Sparse Mixture of Experts (MoE) block (or a standard MLP if not an MoE layer) with residual connections and RMS normalization.
- config: Configuration object for the model.
- layer_idx (int): Index of the current layer.
- dtype (jnp.dtype): Data type for computations.
- param_dtype (jnp.dtype): Data type for parameters.
- precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
- rngs (nn.Rngs): Random number generators.
- self_attn: The self-attention module.
- mlp: The sparse MoE block.
- post_attention_layernorm: RMS normalization applied after the attention layer and before the MoE block.
- dropout (nn.Dropout): Dropout layer (possibly unused; dropout is often handled within submodules).
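A minimal sketch of the residual wiring this layer describes, assuming the common pre-norm layout (normalize, run the sub-block, add the result back to the residual stream). The normalization before attention (`input_norm_w` here) is an assumption, since only `post_attention_layernorm` is listed above; the normalization is written as RMSNorm to follow the description, the sub-blocks are stand-in callables, and all function and argument names are hypothetical.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    # Divide each hidden vector by its root-mean-square, then apply a learned scale.
    return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps) * weight


def decoder_layer(x, attn_fn, moe_fn, input_norm_w, post_attn_norm_w):
    """Pre-norm residual layout (sketch): attention sub-block, then MoE sub-block."""
    # Attention sub-block: normalize, attend, add the result to the residual stream.
    h = x + attn_fn(rms_norm(x, input_norm_w))
    # MoE sub-block: post-attention norm, sparse MoE, second residual add.
    return h + moe_fn(rms_norm(h, post_attn_norm_w))


hidden = 8
x = jnp.ones((3, hidden))
ones = jnp.ones(hidden)

# Stand-in sub-blocks: identity "attention" and an all-zero "MoE".
out = decoder_layer(x, attn_fn=lambda t: t, moe_fn=jnp.zeros_like,
                    input_norm_w=ones, post_attn_norm_w=ones)
print(out.shape)  # (3, 8)
```

The residual adds mean that a sub-block returning zeros leaves the hidden states untouched, which is one reason this layout trains stably in deep stacks.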
- class easydel.modules.phimoe.modeling_phimoe.PhiMoeForCausalLM(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
PhiMoE model with a Causal Language Modeling head.
This model consists of the base PhiMoE transformer (PhiMoeModel) followed by a linear layer (lm_head) that projects the transformer's final hidden states to vocabulary-sized logits for next-token prediction. Optionally, the input token embeddings can be tied to (shared with) the output projection layer.
- config: Configuration object for the model.
- dtype (jnp.dtype): Data type for computation.
- param_dtype (jnp.dtype): Data type for parameters.
- precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
- rngs (nn.Rngs): Random number generators.
- model: The core PhiMoE transformer model (PhiMoeModel).
- lm_head: The linear layer that projects hidden states to vocabulary logits.
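A sketch of the output projection with optional weight tying, written with plain arrays: when tied, the `[vocab, hidden]` embedding table is reused (transposed) as the projection, so no separate `lm_head` weight is stored. `lm_logits` and its arguments are hypothetical names, not EasyDeL's API, and biases are omitted.

```python
import jax
import jax.numpy as jnp


def lm_logits(hidden_states, embedding, lm_head_kernel=None):
    """Project final hidden states to vocabulary logits (sketch).

    hidden_states:  [seq, hidden]
    embedding:      [vocab, hidden]  input token embedding table
    lm_head_kernel: [hidden, vocab]  untied output projection, or None to tie
    """
    if lm_head_kernel is None:
        # Tied weights: reuse the embedding table, transposed to [hidden, vocab].
        return hidden_states @ embedding.T
    return hidden_states @ lm_head_kernel


key = jax.random.PRNGKey(0)
k_h, k_e = jax.random.split(key)
vocab, hidden, seq = 50, 8, 5
h = jax.random.normal(k_h, (seq, hidden))
emb = jax.random.normal(k_e, (vocab, hidden))

logits = lm_logits(h, emb)           # tied projection
next_token = jnp.argmax(logits[-1])  # greedy choice at the last position
print(logits.shape)  # (5, 50)
```

Tying saves `vocab * hidden` parameters, at the cost of forcing the input and output token spaces to share one representation.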
- class easydel.modules.phimoe.modeling_phimoe.PhiMoeModel(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
The base PhiMoE model transformer.
This class represents the core transformer architecture of the PhiMoE model, consisting of an embedding layer, multiple PhiMoeDecoderLayer layers (which include Sparse Mixture of Experts blocks), and a final RMS normalization layer.
- config: Configuration object for the model.
- dtype (jnp.dtype): Data type for computation.
- param_dtype (jnp.dtype): Data type for parameters.
- precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
- rngs (nn.Rngs): Random number generators.
- embed_tokens (nn.Embed): Embedding layer for input tokens.
- layers (tp.List[PhiMoeDecoderLayer]): List of decoder layers.
- embed_dropout (nn.Dropout): Dropout layer applied after the embeddings.
- gradient_checkpointing: Gradient checkpointing configuration.
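A data-flow sketch of the base model under simplifying assumptions: the embedding lookup is plain table indexing (what an embedding layer computes), the decoder stack is a loop over stand-in callables, and the final normalization is RMSNorm; dropout and gradient checkpointing are omitted, as at inference. The function and argument names are hypothetical.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    # Divide each hidden vector by its root-mean-square, then apply a learned scale.
    return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps) * weight


def base_model_forward(input_ids, embed_table, layers, final_norm_w):
    """Token ids -> embeddings -> decoder stack -> final norm (sketch, no dropout)."""
    hidden_states = embed_table[input_ids]  # [seq] -> [seq, hidden] table lookup
    for layer in layers:                    # each callable stands in for a decoder layer
        hidden_states = layer(hidden_states)
    return rms_norm(hidden_states, final_norm_w)


vocab, hidden = 10, 4
embed_table = jnp.arange(vocab * hidden, dtype=jnp.float32).reshape(vocab, hidden)
input_ids = jnp.array([1, 2, 3])
toy_layers = [lambda h: h + 1.0, lambda h: h * 2.0]

out = base_model_forward(input_ids, embed_table, toy_layers, jnp.ones(hidden))
print(out.shape)  # (3, 4)
```

With a unit scale, every output row has a root-mean-square of about 1, whatever magnitude the decoder stack produced.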
- class easydel.modules.phimoe.modeling_phimoe.PhiMoeSparseMoeBlock(*args: Any, **kwargs: Any)
Bases: Module
PhiMoE Sparse Mixture of Experts (MoE) Block.
This module implements the core MoE logic, including the router (gate) and the expert layers. It routes each token to the top-k experts based on the router logits and combines the selected experts' outputs as a weighted sum using the routing weights.
- config: Configuration object for the model.
- layer_idx (int): Index of the current layer.
- dtype (jnp.dtype): Data type for computations.
- param_dtype (jnp.dtype): Data type for parameters.
- precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
- rngs (nn.Rngs): Random number generators.
- gate: Linear layer for the router gate.
- experts (tp.List[PhiMoEBlockSparseTop2MLP]): List of expert MLP modules.
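A dense, readable sketch of top-k routing. It assumes Mixtral-style routing (a bias-free linear gate, top-k selection, and a softmax over the selected logits), which is a simplification and may not match PhiMoE's exact router; `sparse_moe` and the toy experts are hypothetical. For clarity every expert processes every token and is then masked, whereas an efficient implementation would dispatch only the routed tokens.

```python
import jax
import jax.numpy as jnp


def sparse_moe(x, gate_w, experts, top_k=2):
    """Route each token to its top_k experts and mix their outputs (sketch).

    x:       [tokens, hidden]
    gate_w:  [hidden, num_experts]  router (gate) weights, no bias
    experts: list of callables mapping [tokens, hidden] -> [tokens, hidden]
    """
    router_logits = x @ gate_w                                  # [tokens, num_experts]
    top_logits, top_idx = jax.lax.top_k(router_logits, top_k)  # [tokens, top_k] each
    # Softmax over the selected logits only, so each token's weights sum to 1.
    top_weights = jax.nn.softmax(top_logits, axis=-1)

    out = jnp.zeros_like(x)
    for e, expert in enumerate(experts):
        # Weight expert e receives for each token (0 where it was not selected).
        w_e = jnp.sum(jnp.where(top_idx == e, top_weights, 0.0), axis=-1)  # [tokens]
        # Dense for readability: every expert processes every token, then gets masked.
        out = out + w_e[:, None] * expert(x)
    return out, router_logits


# Four toy "experts" that scale their input by 1, 2, 3 and 4.
experts = [lambda t, s=s: t * s for s in (1.0, 2.0, 3.0, 4.0)]
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
# Token 0 prefers experts 0 and 1; token 1 prefers experts 3 and 2.
gate_w = jnp.array([[10.0, 9.0, 0.0, 0.0],
                    [0.0, 0.0, 9.0, 10.0]])
out, logits = sparse_moe(x, gate_w, experts)
print(out)
```

Taking the softmax over only the top-k logits is equivalent to taking it over all experts and renormalizing the selected weights, since the shared denominator cancels.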