easydel.modules.phimoe.modeling_phimoe_flax


class easydel.modules.phimoe.modeling_phimoe_flax.PhiMoEAttention(*args: Any, **kwargs: Any)

Bases: AttentionModule

PhiMoE Attention module.

This module implements the multi-head attention mechanism used in the PhiMoE model, which is similar to the one in Phi-3. It supports Grouped Query Attention (GQA) and Rotary Position Embeddings (RoPE), including scaling options.
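
The grouped-query part of this description can be illustrated with plain jax.numpy. The sketch below is not EasyDeL's implementation, which delegates the core computation to attention_performer. It assumes a [batch, seq, heads, head_dim] layout and omits RoPE, dropout, and KV caching. It shows how each key/value head is shared by num_key_value_groups query heads under a causal mask.

```python
import jax
import jax.numpy as jnp


def gqa_causal_attention(q, k, v):
    """Causal grouped-query attention (illustrative sketch, not EasyDeL's code).

    q:    [batch, seq, num_heads, head_dim]
    k, v: [batch, seq, num_key_value_heads, head_dim]
    """
    num_heads, num_key_value_heads = q.shape[2], k.shape[2]
    # Number of query heads that share each key/value head.
    num_key_value_groups = num_heads // num_key_value_heads
    # Repeat each K/V head so query heads [g*n, (g+1)*n) all use K/V head g.
    k = jnp.repeat(k, num_key_value_groups, axis=2)
    v = jnp.repeat(v, num_key_value_groups, axis=2)

    scale = q.shape[-1] ** -0.5
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale  # [batch, heads, q_len, k_len]
    seq_len = q.shape[1]
    causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    # Masked (future) positions get the most negative float, so softmax gives them ~0 weight.
    scores = jnp.where(causal_mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)  # [batch, seq, heads, head_dim]


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 5, 8, 16))  # 8 query heads
k = jax.random.normal(kk, (2, 5, 2, 16))  # 2 key/value heads -> groups of 4
v = jax.random.normal(kv, (2, 5, 2, 16))
out = gqa_causal_attention(q, k, v)
print(out.shape)  # (2, 5, 8, 16)
```

Repeating K/V heads is the simplest way to express GQA. An optimized kernel would avoid materializing the repeated tensors.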

config

Configuration object for the model.

Type

PhiMoeConfig

layer_idx

Index of the current layer.

Type

int

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

attention_dropout

Dropout probability for attention scores.

Type

float

hidden_size

Dimensionality of the hidden states.

Type

int

num_heads

Number of attention query heads.

Type

int

head_dim

Dimensionality of each attention head.

Type

int

num_key_value_heads

Number of attention key/value heads (for GQA).

Type

int

num_key_value_groups

Number of query heads that share each key/value head (num_heads // num_key_value_heads).

Type

int

max_position_embeddings

Maximum sequence length supported by RoPE.

Type

int

original_max_position_embeddings

Original max sequence length for RoPE scaling.

Type

int

rope_theta

Base value for RoPE frequency calculation.

Type

float

rope_scaling

Configuration for RoPE scaling.

Type

dict

is_causal

Whether the attention is causal (always True for this implementation).

Type

bool

q_proj

Linear layer for query projection.

Type

ParallelLinear

k_proj

Linear layer for key projection.

Type

ParallelLinear

v_proj

Linear layer for value projection.

Type

ParallelLinear

o_proj

Linear layer for the output projection.

Type

ParallelLinear

attention_performer

Module to perform the core attention computation.

Type

FlexibleAttentionModule

rotary

Rotary position embedding module.

Type

RoPE
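
The rotary, rope_theta, and max_position_embeddings attributes refer to Rotary Position Embeddings. As an illustration only, the sketch below builds cos/sin tables and applies the "rotate-half" RoPE convention found in Hugging Face's Llama implementation. Whether EasyDeL's RoPE module uses exactly this layout is an assumption, and the rope_scaling variants are not shown.

```python
import jax
import jax.numpy as jnp


def rope_tables(head_dim, max_position_embeddings, rope_theta=10000.0):
    # Inverse frequency per dimension pair: rope_theta ** (-2i / head_dim).
    inv_freq = 1.0 / (rope_theta ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim))
    positions = jnp.arange(max_position_embeddings, dtype=jnp.float32)
    angles = jnp.outer(positions, inv_freq)              # [max_pos, head_dim // 2]
    angles = jnp.concatenate([angles, angles], axis=-1)  # [max_pos, head_dim]
    return jnp.cos(angles), jnp.sin(angles)


def rotate_half(x):
    # Pair dimension i with dimension i + head_dim // 2.
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate([-x2, x1], axis=-1)


def apply_rope(x, cos, sin, position_ids):
    """x: [batch, seq, heads, head_dim]; position_ids: [batch, seq]."""
    cos = cos[position_ids][:, :, None, :]  # [batch, seq, 1, head_dim], broadcast over heads
    sin = sin[position_ids][:, :, None, :]
    return x * cos + rotate_half(x) * sin


cos, sin = rope_tables(head_dim=8, max_position_embeddings=16)
x = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 2, 8))
position_ids = jnp.arange(4)[None, :]
y = apply_rope(x, cos, sin, position_ids)
```

Because each pair of dimensions is rotated by a position-dependent angle, vector norms are preserved and position 0 is left unchanged. The dot product of a rotated query and a rotated key depends only on their relative offset.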

class easydel.modules.phimoe.modeling_phimoe_flax.PhiMoEBlockSparseTop2MLP(*args: Any, **kwargs: Any)

Bases: Module

PhiMoE Block Sparse Top-2 MLP module.

This module implements the feed-forward network (MLP) for a single expert in the PhiMoE model’s Mixture of Experts layer. It uses a Gated Linear Unit (GLU) structure with SiLU activation.
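
The GLU structure of a single expert fits in a few lines. This sketch assumes the Mixtral-style convention suggested by the w1/w2/w3 attribute descriptions below: output = w2(silu(w1 x) * w3 x). It also represents each ParallelLinear as a plain bias-free weight matrix applied as x @ W. It is not EasyDeL's code.

```python
import jax
import jax.numpy as jnp


def expert_mlp(x, w1, w2, w3):
    """Gated (SwiGLU-style) feed-forward pass for one expert.

    x: [tokens, hidden]; w1, w3: [hidden, intermediate]; w2: [intermediate, hidden]
    """
    gate = jax.nn.silu(x @ w1)  # activated gate branch
    up = x @ w3                 # linear value branch
    return (gate * up) @ w2     # element-wise product, then down-projection


k1, k2, k3, kx = jax.random.split(jax.random.PRNGKey(0), 4)
hidden, intermediate = 8, 32
w1 = jax.random.normal(k1, (hidden, intermediate))
w2 = jax.random.normal(k2, (intermediate, hidden))
w3 = jax.random.normal(k3, (hidden, intermediate))
x = jax.random.normal(kx, (3, hidden))
out = expert_mlp(x, w1, w2, w3)  # shape (3, 8)
```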

config

Configuration object for the model.

Type

PhiMoeConfig

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

w1

First linear layer (part of the GLU gate).

Type

ParallelLinear

w2

Second linear layer (down-projection).

Type

ParallelLinear

w3

Third linear layer (part of the GLU value).

Type

ParallelLinear

act_fn

Activation function (SiLU).

Type

callable

class easydel.modules.phimoe.modeling_phimoe_flax.PhiMoeDecoderLayer(*args: Any, **kwargs: Any)

Bases: Module

PhiMoE Transformer Decoder Layer.

This module represents a single decoder layer in the PhiMoE model. It combines self-attention and a Sparse Mixture of Experts (MoE) block (or a standard MLP if not an MoE layer) with residual connections and RMS normalization.
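
The residual wiring described above can be sketched as a pre-norm block: normalize, apply the sublayer, and add the residual, once for attention and once for the MoE block. Following the attribute descriptions on this page, the sketch uses RMS normalization. attn_fn and moe_fn are hypothetical stand-ins for self_attn and mlp. The ordering is an assumption based on input_layernorm being applied before attention and post_attention_layernorm before the MoE block.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    # Scale by the reciprocal root mean square over the hidden dimension.
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * weight


def decoder_layer(hidden_states, attn_fn, moe_fn, input_ln_weight, post_attn_ln_weight):
    # Attention sub-block: pre-norm, attention, residual add.
    hidden_states = hidden_states + attn_fn(rms_norm(hidden_states, input_ln_weight))
    # MoE sub-block: pre-norm, sparse MoE, residual add.
    return hidden_states + moe_fn(rms_norm(hidden_states, post_attn_ln_weight))


def zero_sublayer(t):
    # Hypothetical sublayer that contributes nothing, for a sanity check.
    return jnp.zeros_like(t)


h = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 4))
ones = jnp.ones((4,))
# With both sublayers returning zeros, the layer reduces to the identity.
out = decoder_layer(h, zero_sublayer, zero_sublayer, ones, ones)
```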

config

Configuration object for the model.

Type

PhiMoeConfig

layer_idx

Index of the current layer.

Type

int

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

input_layernorm

RMS normalization applied before the attention layer.

Type

RMSNorm

self_attn

The self-attention module.

Type

PhiMoEAttention

mlp

The Sparse MoE block.

Type

PhiMoeSparseMoeBlock

post_attention_layernorm

RMS normalization applied after the attention layer and before the MoE block.

Type

RMSNorm

dropout

Dropout layer (possibly unused, since dropout is often applied inside the submodules).

Type

nn.Dropout

class easydel.modules.phimoe.modeling_phimoe_flax.PhiMoeForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

PhiMoE model with a Causal Language Modeling head.

This model consists of the base PhiMoE transformer (PhiMoeModel) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to logits over the vocabulary for next-token prediction. Optionally, the input token embeddings can be tied to the output projection layer.
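
The projection to vocabulary logits, with or without tied embeddings, amounts to a single matrix product. This minimal sketch assumes an embedding matrix of shape [vocab_size, hidden_size] and an untied head kernel of shape [hidden_size, vocab_size]. It does not reproduce how EasyDeL stores or shards these weights.

```python
import jax
import jax.numpy as jnp


def lm_logits(hidden_states, embedding, lm_head_kernel=None):
    """Project hidden states [batch, seq, hidden] to logits [batch, seq, vocab]."""
    if lm_head_kernel is None:
        # Tied weights: reuse the [vocab, hidden] embedding matrix, transposed.
        return hidden_states @ embedding.T
    # Untied: a separate [hidden, vocab] projection.
    return hidden_states @ lm_head_kernel


vocab_size, hidden_size = 10, 4
key_e, key_h = jax.random.split(jax.random.PRNGKey(0))
embedding = jax.random.normal(key_e, (vocab_size, hidden_size))
hidden = jax.random.normal(key_h, (2, 3, hidden_size))
logits = lm_logits(hidden, embedding)
next_token = jnp.argmax(logits[:, -1], axis=-1)  # greedy choice at the last position
```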

config

Configuration object for the model.

Type

PhiMoeConfig

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

model

The core PhiMoE transformer model.

Type

PhiMoeModel

lm_head

The linear layer for projecting hidden states to vocabulary logits.

Type

ParallelLinear

class easydel.modules.phimoe.modeling_phimoe_flax.PhiMoeModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

The base PhiMoE model transformer.

This class represents the core transformer architecture of the PhiMoE model, consisting of an embedding layer, a stack of PhiMoeDecoderLayer modules (each containing a Sparse Mixture of Experts block), and a final RMS normalization layer.
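
The data flow through the base model (embedding lookup, then the decoder stack, then the final norm) can be sketched as follows. The layers and the final norm are passed in as hypothetical callables so the example stays self-contained. embed_dropout, attention masks, and caching are omitted.

```python
import jax.numpy as jnp


def base_model_forward(input_ids, embedding, layer_fns, final_norm_fn):
    # Token embedding lookup: [batch, seq] -> [batch, seq, hidden]
    hidden = jnp.take(embedding, input_ids, axis=0)
    # Run the decoder layers in order.
    for layer_fn in layer_fns:
        hidden = layer_fn(hidden)
    # Final normalization before the LM head.
    return final_norm_fn(hidden)


embedding = jnp.arange(12.0).reshape(6, 2)  # vocab 6, hidden 2
input_ids = jnp.array([[0, 5, 2]])
layers = [lambda h: h + 1.0, lambda h: h * 2.0]  # hypothetical stand-in layers
out = base_model_forward(input_ids, embedding, layers, lambda h: h)
```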

config

Configuration object for the model.

Type

PhiMoeConfig

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

embed_tokens

Embedding layer for input tokens.

Type

nn.Embed

layers

List of decoder layers.

Type

tp.List[PhiMoeDecoderLayer]

norm

Final layer normalization.

Type

RMSNorm

embed_dropout

Dropout layer applied after embeddings.

Type

nn.Dropout

gradient_checkpointing

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers
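
Gradient checkpointing trades compute for memory by recomputing a layer's intermediate activations during the backward pass instead of storing them. The sketch below uses JAX's own jax.checkpoint with the jax.checkpoint_policies.nothing_saveable policy to show the idea. How EasyDeLGradientCheckPointers values map onto JAX checkpoint policies is not documented here, and layer is a hypothetical stand-in for a decoder layer.

```python
import jax
import jax.numpy as jnp


def layer(params, h):
    # Hypothetical stand-in for one decoder layer.
    return jnp.tanh(h @ params)


def loss(params, x, use_remat):
    # With use_remat=True, each layer's activations are recomputed in the backward pass.
    if use_remat:
        f = jax.checkpoint(layer, policy=jax.checkpoint_policies.nothing_saveable)
    else:
        f = layer
    h = x
    for _ in range(4):
        h = f(params, h)
    return jnp.sum(h ** 2)


params = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (8, 8))
x = jnp.ones((2, 8))
grads_plain = jax.grad(loss)(params, x, False)
grads_remat = jax.grad(loss)(params, x, True)
```

The gradients should match with and without rematerialization. Only peak memory and compute change.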

class easydel.modules.phimoe.modeling_phimoe_flax.PhiMoeSparseMoeBlock(*args: Any, **kwargs: Any)

Bases: Module

PhiMoE Sparse Mixture of Experts (MoE) Block.

This module implements the core MoE logic, including the router (gate) and the expert layers. It routes each token to the top-k experts according to the router logits and combines the selected experts’ outputs, weighted by their routing weights.
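
Top-k routing can be sketched with a softmax over the router logits, jax.lax.top_k, and a weighted sum of the selected experts' outputs. This is a generic Mixtral-style formulation, not necessarily the exact PhiMoE router, whose weighting scheme may differ. For clarity it runs every expert on every token, which makes it a dense reference rather than an efficient sparse dispatch. The expert functions are hypothetical stand-ins for the experts list.

```python
import jax
import jax.numpy as jnp


def sparse_moe(x, gate_kernel, expert_fns, top_k=2):
    """x: [tokens, hidden]; gate_kernel: [hidden, num_experts]."""
    router_logits = x @ gate_kernel                          # [tokens, num_experts]
    probs = jax.nn.softmax(router_logits, axis=-1)
    top_w, top_idx = jax.lax.top_k(probs, top_k)             # [tokens, top_k]
    top_w = top_w / jnp.sum(top_w, axis=-1, keepdims=True)   # renormalize over chosen experts
    # Dense reference: evaluate all experts, then gather the chosen ones.
    expert_outs = jnp.stack([f(x) for f in expert_fns], axis=1)  # [tokens, num_experts, hidden]
    idx = jnp.broadcast_to(top_idx[..., None], top_idx.shape + (expert_outs.shape[-1],))
    selected = jnp.take_along_axis(expert_outs, idx, axis=1)     # [tokens, top_k, hidden]
    return jnp.sum(selected * top_w[..., None], axis=1), router_logits


tokens, hidden, num_experts = 5, 4, 4
kx, kg = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (tokens, hidden))
gate_kernel = jax.random.normal(kg, (hidden, num_experts))
# Hypothetical experts: expert i scales its input by (i + 1).
expert_fns = [lambda h, s=float(i + 1): h * s for i in range(num_experts)]
out, router_logits = sparse_moe(x, gate_kernel, expert_fns)
```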

config

Configuration object for the model.

Type

PhiMoeConfig

layer_idx

Index of the current layer.

Type

int

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

gate

Linear layer for the router gate.

Type

ParallelLinear

experts

List of expert MLP modules.

Type

tp.List[PhiMoEBlockSparseTop2MLP]