easydel.modules.phimoe.modeling_phimoe

class easydel.modules.phimoe.modeling_phimoe.PhiMoEAttention(*args: Any, **kwargs: Any)

Bases: UnifiedAttention

PhiMoE self-attention module built on UnifiedAttention, with an optional sharding constraint.

class easydel.modules.phimoe.modeling_phimoe.PhiMoEBlockSparseTop2MLP(*args: Any, **kwargs: Any)

Bases: Module

PhiMoE Block Sparse Top-2 MLP module.

This module implements the feed-forward network (MLP) for a single expert in the PhiMoE model’s Mixture of Experts layer. It uses a Gated Linear Unit (GLU) structure with SiLU activation.
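The gated expert computation can be sketched in plain JAX as below. This is a minimal sketch, assuming the common Mixtral-style composition w2(act_fn(w1(x)) * w3(x)); plain weight matrices stand in for the ParallelLinear layers, and expert_mlp is a hypothetical helper, not part of the EasyDeL API.

```python
import jax
import jax.numpy as jnp


def expert_mlp(x, w1, w2, w3):
    """Hypothetical gated expert: w2(silu(x @ w1) * (x @ w3))."""
    gate = jax.nn.silu(x @ w1)   # gate branch, shape (..., intermediate)
    value = x @ w3               # value branch, shape (..., intermediate)
    return (gate * value) @ w2   # down-projection back to (..., hidden)


hidden, intermediate = 8, 16
k1, k2, k3, kx = jax.random.split(jax.random.PRNGKey(0), 4)
w1 = 0.1 * jax.random.normal(k1, (hidden, intermediate))
w3 = 0.1 * jax.random.normal(k3, (hidden, intermediate))
w2 = 0.1 * jax.random.normal(k2, (intermediate, hidden))
x = jax.random.normal(kx, (2, 5, hidden))  # (batch, seq, hidden)
y = expert_mlp(x, w1, w2, w3)
print(y.shape)  # (2, 5, 8)
```

Because SiLU(0) = 0, a zero input yields a zero output; the gate branch multiplicatively controls how much of the value branch passes through to the down-projection.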

config

Configuration object for the model.

Type

PhiMoeConfig

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

w1

First linear layer (part of the GLU gate).

Type

ParallelLinear

w2

Second linear layer (down-projection).

Type

ParallelLinear

w3

Third linear layer (part of the GLU value).

Type

ParallelLinear

act_fn

Activation function (SiLU).

Type

callable

class easydel.modules.phimoe.modeling_phimoe.PhiMoeDecoderLayer(*args: Any, **kwargs: Any)

Bases: Module

PhiMoE Transformer Decoder Layer.

This module represents a single decoder layer in the PhiMoE model. It applies self-attention followed by a sparse Mixture of Experts (MoE) block (or a standard MLP for non-MoE layers), each preceded by RMS normalization and wrapped in a residual connection.
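Following the attribute descriptions below (input_layernorm before attention, post_attention_layernorm between attention and the MoE block), the layer's data flow can be sketched as a pre-norm residual block. This is a minimal sketch under stated assumptions: rms_norm and decoder_layer are hypothetical helpers, the attention and MoE submodules are replaced by plain callables, masking, caching and dropout are omitted, and the MoE callable is assumed to return only hidden states.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    # Divide by the root-mean-square over the hidden axis, then scale.
    var = jnp.mean(x * x, axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(var + eps) * weight


def decoder_layer(x, attn_fn, moe_fn, ln1_w, ln2_w):
    # Hypothetical pre-norm layout: normalize, attend, add the residual.
    h = x + attn_fn(rms_norm(x, ln1_w))
    # Normalize again, route through the MoE block, add the residual.
    return h + moe_fn(rms_norm(h, ln2_w))


hidden = 8
x = jax.random.normal(jax.random.PRNGKey(0), (2, 4, hidden))
ones = jnp.ones((hidden,))
# Stand-ins: identity "attention" and an all-zero "MoE" output.
out = decoder_layer(x, lambda t: t, lambda t: jnp.zeros_like(t), ones, ones)
```

With these stand-ins the output reduces to x + rms_norm(x), which makes the two residual paths easy to check in isolation.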

config

Configuration object for the model.

Type

PhiMoeConfig

layer_idx

Index of the current layer.

Type

int

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

input_layernorm

RMS normalization applied before the attention layer.

Type

RMSNorm

self_attn

The self-attention module.

Type

PhiMoEAttention

mlp

The Sparse MoE block.

Type

PhiMoeSparseMoeBlock

post_attention_layernorm

RMS normalization applied after the attention layer and before the MoE block.

Type

RMSNorm

dropout

Dropout layer (possibly unused here, since dropout is often applied inside the submodules).

Type

nn.Dropout

class easydel.modules.phimoe.modeling_phimoe.PhiMoeForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

PhiMoE model with a Causal Language Modeling head.

This model consists of the base PhiMoE transformer (PhiMoeModel) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to vocabulary-sized logits for next-token prediction. Optionally, the output projection can share (tie) its weights with the input token embeddings.
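The projection and the optional weight tying can be illustrated as follows. This is a minimal sketch, assuming the embedding table has shape (vocab_size, hidden_size) so that tying reuses its transpose as the projection; lm_logits and its arguments are illustrative names, not the EasyDeL API.

```python
import jax
import jax.numpy as jnp


def lm_logits(hidden_states, embedding, lm_head_kernel=None, tie_word_embeddings=False):
    """Hypothetical helper: project (..., hidden) states to (..., vocab) logits."""
    if tie_word_embeddings:
        # Tied: reuse the (vocab, hidden) embedding table, transposed.
        return hidden_states @ embedding.T
    # Untied: a separate (hidden, vocab) projection kernel.
    return hidden_states @ lm_head_kernel


vocab_size, hidden_size = 10, 4
ke, kh = jax.random.split(jax.random.PRNGKey(0))
embedding = jax.random.normal(ke, (vocab_size, hidden_size))
h = jax.random.normal(kh, (2, 3, hidden_size))  # (batch, seq, hidden)
logits = lm_logits(h, embedding, tie_word_embeddings=True)
next_token = jnp.argmax(logits[:, -1, :], axis=-1)  # greedy choice per sequence
print(logits.shape, next_token.shape)  # (2, 3, 10) (2,)
```

Tying removes one (hidden, vocab) parameter matrix; for large vocabularies that is a sizable share of the total parameter count.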

config

Configuration object for the model.

Type

PhiMoeConfig

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

model

The core PhiMoE transformer model.

Type

PhiMoeModel

lm_head

The linear layer for projecting hidden states to vocabulary logits.

Type

ParallelLinear

get_decoder()

Returns the decoder part of the model’s graph definition.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition. Decoder-only models have no encoder.

get_lm_head()

Returns the language model head of the module.

class easydel.modules.phimoe.modeling_phimoe.PhiMoeModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

The base PhiMoE model transformer.

This class represents the core transformer architecture of the PhiMoE model, consisting of an embedding layer, multiple PhiMoeDecoderLayer layers (which include Sparse Mixture of Experts blocks), and a final RMS normalization layer.
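The forward data flow (embedding lookup, the stack of decoder layers, final normalization) can be sketched as follows. This is a minimal sketch: model_forward and rms_norm are hypothetical helpers, simple residual callables stand in for PhiMoeDecoderLayer, and embed_dropout, attention masks, caching and gradient checkpointing are omitted.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    # Divide by the root-mean-square over the hidden axis, then scale.
    var = jnp.mean(x * x, axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(var + eps) * weight


def model_forward(input_ids, embed_table, layer_fns, norm_weight):
    # Embedding lookup: (batch, seq) int ids -> (batch, seq, hidden).
    h = jnp.take(embed_table, input_ids, axis=0)
    # Apply each decoder layer in order.
    for layer in layer_fns:
        h = layer(h)
    # Final normalization of the last hidden states.
    return rms_norm(h, norm_weight)


vocab, hidden, num_layers = 16, 8, 3
embed_table = jax.random.normal(jax.random.PRNGKey(0), (vocab, hidden))
# Hypothetical stand-ins for decoder layers: small residual tanh updates.
layer_fns = [lambda h, s=i: h + 0.1 * jnp.tanh(h * (s + 1)) for i in range(num_layers)]
input_ids = jnp.array([[1, 5, 7, 2], [3, 3, 0, 9]])
norm_w = jnp.ones((hidden,))
hidden_states = model_forward(input_ids, embed_table, layer_fns, norm_w)
```

The returned hidden states are what a causal LM head would then project to vocabulary logits.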

config

Configuration object for the model.

Type

PhiMoeConfig

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

embed_tokens

Embedding layer for input tokens.

Type

nn.Embed

layers

List of decoder layers.

Type

tp.List[PhiMoeDecoderLayer]

norm

Final layer normalization.

Type

RMSNorm

embed_dropout

Dropout layer applied after embeddings.

Type

nn.Dropout

gradient_checkpointing

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers

get_decoder()

Returns the decoder part of the model’s graph definition.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition. Decoder-only models have no encoder.

get_lm_head()

Returns the language model head of the module. Base models have no language model head.

class easydel.modules.phimoe.modeling_phimoe.PhiMoeSparseMoeBlock(*args: Any, **kwargs: Any)

Bases: Module

PhiMoE Sparse Mixture of Experts (MoE) Block.

This module implements the core MoE logic, including the router (gate) and the expert layers. It routes each token to the top-k experts selected from the router logits and combines the selected experts’ outputs, weighted by their routing scores.
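A dense reference version of top-k routing can be sketched as below. This is a minimal sketch under loud assumptions: moe_block is a hypothetical helper, the routing weights are a softmax over the selected top-k logits (Mixtral-style), and every expert runs on every token for clarity. PhiMoE's actual weighting may differ (the Hugging Face port, for instance, uses a SparseMixer-style scheme), and practical implementations dispatch only the routed tokens to each expert.

```python
import jax
import jax.numpy as jnp


def moe_block(x, gate_w, expert_fns, top_k=2):
    """Hypothetical dense top-k MoE over x of shape (tokens, hidden)."""
    num_experts = len(expert_fns)
    router_logits = x @ gate_w                               # (tokens, experts)
    top_vals, top_idx = jax.lax.top_k(router_logits, top_k)  # (tokens, k) each
    weights = jax.nn.softmax(top_vals, axis=-1)              # renormalize over chosen experts
    # Scatter the k weights into a (tokens, experts) matrix; unchosen experts get 0.
    combine = jnp.einsum("tk,tke->te", weights, jax.nn.one_hot(top_idx, num_experts))
    # Dense reference: run every expert on every token.
    expert_out = jnp.stack([f(x) for f in expert_fns], axis=1)  # (tokens, experts, hidden)
    out = jnp.einsum("te,teh->th", combine, expert_out)
    return out, router_logits


tokens, hidden, num_experts = 6, 4, 4
kx, kg = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (tokens, hidden))
gate_w = jax.random.normal(kg, (hidden, num_experts))
# Hypothetical experts: expert i scales its input by (i + 1).
expert_fns = [lambda t, s=i: t * (s + 1.0) for i in range(num_experts)]
out, router_logits = moe_block(x, gate_w, expert_fns, top_k=2)
```

Since the selected weights sum to 1 for every token, replacing all experts with the identity returns the input unchanged, a convenient sanity check for the combine step.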

config

Configuration object for the model.

Type

PhiMoeConfig

layer_idx

Index of the current layer.

Type

int

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

gate

Linear layer for the router gate.

Type

ParallelLinear

experts

List of expert MLP modules.

Type

tp.List[PhiMoEBlockSparseTop2MLP]