# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools

import chex
import jax.lax
from chex import Array
from eformer import common_types
from eformer.escale import apply_logical_sharding
from ejkernel.types import MaskInfo
from flax import nnx as nn
from jax import numpy as jnp
from jax.ad_checkpoint import checkpoint_name
from jaxtyping import Bool, Float, Int

from easydel.infra.base_module import EasyDeLBaseModule
from easydel.infra.factory import TaskType, register_module
from easydel.infra.modeling_outputs import BaseModelOutput, CausalLMOutput, DecoderLayerOutput
from easydel.infra.utils import ACT2FN, auto_remat, block_wise_ffn, get_dot_general_by_bits
from easydel.layers.attention_unified import UnifiedAttention
from easydel.layers.caching import (
    RaggedPagesCache,
    RaggedPagesCacheView,
    RaggedPagesMetadata,
    TransformerCache,
    TransformerCacheView,
    TransformerMetadata,
)
from easydel.layers.linear import ColumnParallelLinear, RowParallelLinear
from easydel.layers.norms import RMSNorm as RMSNorm

from .phimoe_configuration import PhiMoeConfig
class PhiMoEBlockSparseTop2MLP(nn.Module):
    """Feed-forward network of a single PhiMoE expert.

    The expert is a gated MLP: ``w2(act_fn(w1(x)) * w3(x))``, where ``act_fn`` is
    looked up from ``ACT2FN`` using ``config.hidden_act``. None of the three
    projections carries a bias.

    Attributes:
        config (PhiMoeConfig): Configuration object for the model.
        dtype (jnp.dtype): Data type for computations.
        param_dtype (jnp.dtype): Data type for parameters.
        precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
        w1 (ColumnParallelLinear): hidden_size -> intermediate_size gate projection.
        w2 (RowParallelLinear): intermediate_size -> hidden_size down projection.
        w3 (ColumnParallelLinear): hidden_size -> intermediate_size value projection.
        act_fn (callable): Activation applied to the gate projection.
    """

    def __init__(
        self,
        config: PhiMoeConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Build the three bias-free projections of the expert.

        Args:
            config (PhiMoeConfig): Model configuration (sizes, activation,
                initializer range and quantization settings are read from it).
            dtype (jnp.dtype): Computation dtype. Defaults to jnp.bfloat16.
            param_dtype (jnp.dtype): Parameter dtype. Defaults to jnp.bfloat16.
            precision (jax.lax.PrecisionLike): JAX matmul precision. Defaults to None.
            rngs (nn.Rngs): Random number generators used for parameter init.
        """
        self.config = config
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision

        def _linear_factory(linear_cls):
            # One factory per layer family; the dot-general kwargs are built once
            # per family, so both column-parallel layers share them.
            return functools.partial(
                linear_cls,
                dtype=dtype,
                param_dtype=param_dtype,
                use_bias=False,
                kernel_init=jax.nn.initializers.normal(config.initializer_range),
                precision=precision,
                rngs=rngs,
                **get_dot_general_by_bits(config.bits, config.easy_method),
            )

        make_column = _linear_factory(ColumnParallelLinear)
        make_row = _linear_factory(RowParallelLinear)

        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size
        # Creation order (w1, w2, w3) determines RNG consumption; keep it fixed.
        self.w1 = make_column(hidden_size, intermediate_size)
        self.w2 = make_row(intermediate_size, hidden_size)
        self.w3 = make_column(hidden_size, intermediate_size)
        self.act_fn = ACT2FN[self.config.hidden_act]

    def __call__(self, hidden_states: Array) -> Array:
        """Apply the gated expert MLP.

        Args:
            hidden_states (Array): Token representations whose trailing
                dimension is ``hidden_size``.

        Returns:
            Array: Expert output with the same leading shape and a trailing
            dimension of ``hidden_size``.
        """
        gated = checkpoint_name(self.act_fn(self.w1(hidden_states)), "mlp_gate")
        projected = checkpoint_name(self.w3(hidden_states), "mlp_up")
        output = checkpoint_name(self.w2(gated * projected), "mlp_down")
        return checkpoint_name(output, "mlp_output")
class PhiMoEAttention(UnifiedAttention):
    """PhiMoE self-attention built on :class:`UnifiedAttention`.

    Configures the unified attention module as standard causal attention,
    forwarding ``config.sliding_window`` as the sliding-window size.
    """

    def __init__(
        self,
        config: PhiMoeConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
        layer_idx: int,
    ):
        """Initialize the attention layer.

        Args:
            config (PhiMoeConfig): Model configuration.
            dtype (jnp.dtype): Computation dtype. Defaults to jnp.bfloat16.
            param_dtype (jnp.dtype): Parameter dtype. Defaults to jnp.bfloat16.
            precision (jax.lax.PrecisionLike): JAX matmul precision. Defaults to None.
            rngs (nn.Rngs): Random number generators used for parameter init.
            layer_idx (int): Index of the decoder layer owning this attention.
        """
        attention_options = {
            "attention_type": "standard",
            "causal": True,
            "sliding_window": config.sliding_window,
        }
        super().__init__(
            config=config,
            layer_idx=layer_idx,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
            **attention_options,
        )
class PhiMoeSparseMoeBlock(nn.Module):
    """PhiMoE sparse Mixture-of-Experts (MoE) block.

    A linear router (``gate``) scores every token against ``num_local_experts``
    experts. The ``num_experts_per_tok`` highest router logits are kept and
    renormalised with a softmax. Every expert is evaluated on every token, and
    its output is scaled by that token's routing weight. The weight is zero for
    experts outside the token's top-k. The scaled outputs are then summed.

    NOTE(review): the Hugging Face PhiMoE reference routes with ``sparsemixer``,
    which uses ``router_jitter_noise`` as its masking epsilon. This block uses
    plain top-k + softmax, and ``router_jitter_noise`` is stored but not used.
    Confirm the approximation is intended.

    Attributes:
        config (PhiMoeConfig): Configuration object for the model.
        layer_idx (int): Index of the current layer.
        dtype (jnp.dtype): Data type for computations.
        param_dtype (jnp.dtype): Data type for parameters.
        precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
        hidden_dim (int): ``config.hidden_size``.
        ffn_dim (int): ``config.intermediate_size``.
        num_experts (int): ``config.num_local_experts``.
        top_k (int): ``config.num_experts_per_tok``.
        router_jitter_noise (float): ``config.router_jitter_noise`` (unused here).
        input_jitter_noise (float): Half-width of the multiplicative input jitter.
        gate (ColumnParallelLinear): Router projection hidden_size -> num_local_experts.
        experts (list[PhiMoEBlockSparseTop2MLP]): The expert MLPs.
    """

    def __init__(
        self,
        config: PhiMoeConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
        layer_idx: int,
    ):
        """Initialize the router and the expert MLPs.

        Args:
            config (PhiMoeConfig): Model configuration.
            dtype (jnp.dtype): Computation dtype. Defaults to jnp.bfloat16.
            param_dtype (jnp.dtype): Parameter dtype. Defaults to jnp.bfloat16.
            precision (jax.lax.PrecisionLike): JAX matmul precision. Defaults to None.
            rngs (nn.Rngs): Random number generators used for parameter init.
            layer_idx (int): Index of the decoder layer owning this block.
        """
        self.config = config
        self.layer_idx = layer_idx
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok
        self.router_jitter_noise = config.router_jitter_noise
        self.input_jitter_noise = config.input_jitter_noise
        self.gate = ColumnParallelLinear(
            self.config.hidden_size,
            self.config.num_local_experts,
            use_bias=False,
            rngs=rngs,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            # Same initializer as the expert projections.
            kernel_init=jax.nn.initializers.normal(config.initializer_range),
        )
        self.experts = [
            PhiMoEBlockSparseTop2MLP(
                config=config,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )
            for _ in range(self.config.num_local_experts)
        ]

    def __call__(
        self,
        hidden_states: chex.Array,
        deterministic: bool = False,
        jitter_key: Array | None = None,
    ) -> tuple[chex.Array, chex.Array]:
        """Route tokens to experts and combine the weighted expert outputs.

        Args:
            hidden_states (chex.Array): Input of shape
                (batch_size, sequence_length, hidden_size).
            deterministic (bool): When True, input jitter is never applied.
                Defaults to False.
            jitter_key (Array | None): PRNG key for the multiplicative input
                jitter. The jitter is applied only when ``deterministic`` is
                False, ``input_jitter_noise > 0`` and this key is provided.
                Defaults to None, which means no jitter.

        Returns:
            tuple[chex.Array, chex.Array]:
                - final_hidden_states: Shape (batch_size, sequence_length, hidden_size).
                - router_logits: Shape (batch_size * sequence_length, num_local_experts).

        Raises:
            ValueError: If ``hidden_states`` is not rank 3.
        """
        if hidden_states.ndim != 3:
            raise ValueError(
                "PhiMoeSparseMoeBlock expects hidden_states of shape "
                f"(batch_size, sequence_length, hidden_size), got {hidden_states.shape}."
            )
        compute_dtype = jnp.promote_types(self.dtype, jnp.float32)

        if not deterministic and self.input_jitter_noise > 0 and jitter_key is not None:
            # Multiplicative jitter on the block input: x *= U(1 - eps, 1 + eps).
            jitter = jax.random.uniform(
                jitter_key,
                hidden_states.shape,
                dtype=hidden_states.dtype,
                minval=1.0 - self.input_jitter_noise,
                maxval=1.0 + self.input_jitter_noise,
            )
            hidden_states = hidden_states * jitter

        router_logits = self.gate(hidden_states).astype(compute_dtype)
        routing_weights, selected_experts = jax.lax.top_k(router_logits, k=self.top_k)
        routing_weights = jax.nn.softmax(routing_weights, axis=-1)

        final_hidden_state = jnp.zeros_like(hidden_states)
        for expert_index, expert in enumerate(self.experts):
            # block_wise_ffn chunks the sequence axis, so inputs stay 3-D.
            if self.config.use_scan_mlp:
                expert_output = block_wise_ffn(expert, hidden_states, self.config.scan_mlp_chunk_size)
            else:
                expert_output = expert(hidden_states)
            # Per-token weight of this expert (0 where it is not in the top-k).
            token_weight = jnp.sum(
                jnp.where(selected_experts == expert_index, routing_weights, 0.0),
                axis=-1,
            )
            weighted_output = token_weight[..., None] * expert_output
            final_hidden_state += checkpoint_name(weighted_output, "moe_expert_output")

        router_logits = router_logits.reshape(-1, router_logits.shape[-1])
        return final_hidden_state, checkpoint_name(router_logits, "moe_router_logits")
class PhiMoeDecoderLayer(nn.Module):
    """A single PhiMoE transformer decoder layer.

    Pre-norm residual layout:
    ``x = x + self_attn(input_layernorm(x))``, then
    ``x = x + block_sparse_moe(post_attention_layernorm(x))``.
    Both norms are ``nn.LayerNorm`` with bias, using ``config.rms_norm_eps`` as
    epsilon. The attention and MoE classes are wrapped with ``auto_remat``
    according to the gradient-checkpointing settings in ``config``.

    Attributes:
        config (PhiMoeConfig): Configuration object for the model.
        dtype (jnp.dtype): Data type for computations.
        param_dtype (jnp.dtype): Data type for parameters.
        precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
        self_attn (PhiMoEAttention): The self-attention module.
        block_sparse_moe (PhiMoeSparseMoeBlock): The sparse MoE block.
        input_layernorm (nn.LayerNorm): Normalization before self-attention.
        post_attention_layernorm (nn.LayerNorm): Normalization before the MoE block.
    """

    def __init__(
        self,
        config: PhiMoeConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
        layer_idx: int,
    ):
        """Initialize the decoder layer.

        Args:
            config (PhiMoeConfig): Model configuration.
            dtype (jnp.dtype): Computation dtype. Defaults to jnp.bfloat16.
            param_dtype (jnp.dtype): Parameter dtype. Defaults to jnp.bfloat16.
            precision (jax.lax.PrecisionLike): JAX matmul precision. Defaults to None.
            rngs (nn.Rngs): Random number generators used for parameter init.
            layer_idx (int): Index of this layer in the decoder stack.
        """
        self.config = config
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        attn_block = PhiMoEAttention
        mlp_block = PhiMoeSparseMoeBlock
        attn_block, mlp_block = auto_remat(
            attn_block,
            mlp_block,
            policy=config.gradient_checkpointing,
            save_names=config.gradient_checkpointing_targets,
            exclude_names=config.gradient_checkpointing_targets,
        )
        self.self_attn = attn_block(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
            layer_idx=layer_idx,
        )
        self.block_sparse_moe = mlp_block(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
            layer_idx=layer_idx,
        )
        self.input_layernorm = nn.LayerNorm(
            config.hidden_size,
            epsilon=config.rms_norm_eps,
            dtype=dtype,
            param_dtype=param_dtype,
            use_bias=True,
            rngs=rngs,
        )
        self.post_attention_layernorm = nn.LayerNorm(
            config.hidden_size,
            epsilon=config.rms_norm_eps,
            dtype=dtype,
            param_dtype=param_dtype,
            use_bias=True,
            rngs=rngs,
        )

    def __call__(
        self,
        hidden_states: Float[Array, "batch seq_len hidden_dim"],
        mask_info: MaskInfo,
        position_ids: Int[Array, "batch seq_len"],
        mode: common_types.RUNTIME_MODE_TYPES,  # type:ignore
        cache_view: TransformerCacheView | RaggedPagesCacheView | None = None,
        cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None,
        output_attentions: bool = False,
        output_router_logits: bool = False,
        frequencies: Float[Array, "seq_len head_dim"] | None = None,
    ) -> DecoderLayerOutput:
        """Run self-attention and the sparse MoE block with residual connections.

        Args:
            hidden_states: Input hidden states (batch, seq_len, hidden_dim).
            mask_info (MaskInfo): Attention mask information.
            position_ids: Position indices for the tokens (batch, seq_len).
            mode: Runtime mode (e.g. train / decode) forwarded to attention.
            cache_view: KV-cache view for this layer, if any.
            cache_metadata: Metadata for the cache (e.g. paged attention).
            output_attentions (bool): Whether to return attention weights.
            output_router_logits (bool): Whether to return the MoE router logits.
            frequencies: Precomputed rotary frequency embeddings.

        Returns:
            DecoderLayerOutput: Holds ``hidden_states`` and the (possibly
            updated) ``cache_view`` returned by attention. ``attention_weight``
            is set only when ``output_attentions`` is True, and
            ``router_logits`` only when ``output_router_logits`` is True;
            otherwise each is None.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = apply_logical_sharding(
            hidden_states,
            dynamic_axes=common_types.HiddenStateSharding,
            partition_manager=self.config.partition_manager,
        )
        attn_outputs = self.self_attn(
            hidden_states,
            mask_info,
            position_ids,
            mode,
            cache_view,
            cache_metadata,
            output_attentions,
            frequencies,
        )
        hidden_states = checkpoint_name(residual + attn_outputs.attention_output, "residual")

        # Fully connected (sparse MoE) sub-block.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states, router_logits = self.block_sparse_moe(hidden_states)
        hidden_states = checkpoint_name(residual + hidden_states, "residual")
        hidden_states = checkpoint_name(hidden_states, "layer_output")

        # PhiMoeModel reads .hidden_states / .attention_weight / .cache_view from
        # this result, and the cache view must be propagated for KV caching.
        return DecoderLayerOutput(
            hidden_states=hidden_states,
            attention_weight=attn_outputs.attention_weight if output_attentions else None,
            router_logits=router_logits if output_router_logits else None,
            cache_view=attn_outputs.cache_view,
        )
@register_module(TaskType.BASE_MODULE, config=PhiMoeConfig, model_type="phimoe")
class PhiMoeModel(EasyDeLBaseModule):
    """The base PhiMoE transformer (no language-model head).

    Embeds input tokens, runs them through ``config.num_hidden_layers``
    :class:`PhiMoeDecoderLayer` blocks (self-attention followed by a sparse
    Mixture-of-Experts block), and applies a final ``nn.LayerNorm``.

    Attributes:
        config (PhiMoeConfig): Configuration object for the model.
        padding_idx (int | None): ``config.pad_token_id``.
        vocab_size (int): ``config.vocab_size``.
        embed_tokens (nn.Embed): Token embedding table (class wrapped by ``auto_remat``).
        embed_dropout (nn.Dropout): Dropout built from ``config.embd_pdrop``.
            Note that ``__call__`` does not currently apply it.
        layers (list[PhiMoeDecoderLayer]): The decoder layers.
        norm (nn.LayerNorm): Final layer normalization (with bias).
    """

    def __init__(
        self,
        config: PhiMoeConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Initialize the embeddings, decoder stack and final norm.

        Args:
            config (PhiMoeConfig): Model configuration.
            dtype (jnp.dtype): Computation dtype. Defaults to jnp.bfloat16.
            param_dtype (jnp.dtype): Parameter dtype. Defaults to jnp.bfloat16.
            precision (jax.lax.PrecisionLike): JAX matmul precision. Defaults to None.
            rngs (nn.Rngs): Random number generators used for parameter init.
        """
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        embed_block = auto_remat(
            nn.Embed,
            policy=config.gradient_checkpointing,
            save_names=config.gradient_checkpointing_targets,
            exclude_names=config.gradient_checkpointing_targets,
        )
        self.embed_tokens = embed_block(
            config.vocab_size,
            config.hidden_size,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.embed_dropout = nn.Dropout(config.embd_pdrop)
        self.layers = [
            PhiMoeDecoderLayer(
                config=config,
                layer_idx=idx,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )
            for idx in range(self.config.num_hidden_layers)
        ]
        self.norm = nn.LayerNorm(
            config.hidden_size,
            epsilon=config.rms_norm_eps,
            dtype=dtype,
            param_dtype=param_dtype,
            use_bias=True,
            rngs=rngs,
        )

    def __call__(
        self,
        input_ids: Int[Array, "batch seq_len"] | None = None,
        inputs_embeds: Float[Array, "batch seq_len hidden_dim"] | None = None,
        attention_mask: Bool[Array, "batch seq_len"] | None = None,
        mask_info: MaskInfo | None = None,
        position_ids: Int[Array, "batch seq_len"] | None = None,
        mode: common_types.RUNTIME_MODE_TYPES | None = None,  # type:ignore
        past_key_values: TransformerCache | RaggedPagesCache | None = None,
        cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
    ) -> BaseModelOutput:
        """Run the PhiMoE decoder stack.

        Args:
            input_ids: Input token IDs (batch, seq_len). Exactly one of
                ``input_ids`` / ``inputs_embeds`` must be given.
            inputs_embeds: Precomputed input embeddings (batch, seq_len, hidden_dim).
            attention_mask: Padding mask (batch, seq_len); used to build
                ``mask_info`` via ``MaskInfo.dynamic_init``.
            mask_info (MaskInfo | None): Precomputed mask information.
            position_ids: Position indices (batch, seq_len). When None, they are
                taken from ``mask_info.q_position_ids``.
            mode: Runtime mode. When None, decode mode is chosen if
                ``sequence_length == 1`` and ``past_key_values`` is given;
                otherwise train mode.
            past_key_values: KV cache. When None, an empty ``TransformerCache``
                with one view per layer is created.
            cache_metadata: Metadata for the cache (e.g. paged attention).
            output_attentions (bool | None): If truthy, collect per-layer
                attention weights; None/False means they are not collected.
            output_hidden_states (bool | None): If truthy, collect the input to
                every layer plus the final normalized output.

        Returns:
            BaseModelOutput: ``last_hidden_state``, ``hidden_states`` (or None),
            ``attentions`` (or None), and the updated ``past_key_values``.

        Raises:
            ValueError: If both or neither of ``input_ids`` / ``inputs_embeds``
                are provided.
            AssertionError: If the sequence is longer than
                ``config.max_position_embeddings``.
        """
        # XOR is True when both or neither are provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )
        if inputs_embeds is None:
            inputs_embeds = checkpoint_name(self.embed_tokens(input_ids.astype("i4")), "embeddings")
        sequence_length = inputs_embeds.shape[1]
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        assert sequence_length <= self.config.max_position_embeddings, (
            f"Maximum Position Embedding Reached ! "
            f"(Excepted <= {self.config.max_position_embeddings} got {sequence_length})"
        )
        mask_info = MaskInfo.dynamic_init(
            mask_info=mask_info,
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )
        if position_ids is None:
            position_ids = mask_info.q_position_ids
        if mode is None:
            mode = (
                common_types.MODE_DECODE
                if sequence_length == 1 and past_key_values is not None
                else common_types.MODE_TRAIN
            )
        if past_key_values is None:
            past_key_values = TransformerCache.init_empty(len(self.layers))
        hidden_states = apply_logical_sharding(
            inputs_embeds,
            dynamic_axes=common_types.HiddenStateSharding,
            partition_manager=self.config.partition_manager,
        )
        for idx, block in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            layer_outputs = block(
                hidden_states=hidden_states,
                mask_info=mask_info,
                position_ids=position_ids,
                mode=mode,
                cache_view=past_key_values.views[idx],
                cache_metadata=cache_metadata,
                output_attentions=output_attentions,
                # Rotary frequencies exposed by the base module.
                frequencies=self.frequencies,
            )
            hidden_states = layer_outputs.hidden_states
            if output_attentions:
                all_attentions += (layer_outputs.attention_weight,)
            # Write back the layer's updated cache view.
            past_key_values[idx] = layer_outputs.cache_view
        hidden_states = self.norm(hidden_states)
        hidden_states = checkpoint_name(hidden_states, "model_output")
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            past_key_values=past_key_values,
        )

    def get_encoder(self):
        """
        Returns the encoder part of the model's graph definition.
        Decoder-Only models don't have an encoder.
        """
        raise NotImplementedError("This is a decoder-only model and does not have an encoder.")

    def get_decoder(self):
        """
        Returns the decoder part of the model's graph definition.
        """
        return self

    def get_lm_head(self):
        """
        Returns the language model head of the module.
        Base Models don't have a Language Model Head.
        """
        raise NotImplementedError("The base model does not have a language model head.")

    def get_embedding(self):
        """
        Returns the embedding layer of the module.
        """
        return self.embed_tokens
@register_module(TaskType.CAUSAL_LM, config=PhiMoeConfig, model_type="phimoe")
class PhiMoeForCausalLM(EasyDeLBaseModule):
    """PhiMoE transformer topped with a causal language-modeling head.

    Wraps :class:`PhiMoeModel` and projects its final hidden states to
    vocabulary logits using ``lm_head``, a ``ColumnParallelLinear`` of shape
    hidden_size -> vocab_size (bias controlled by ``config.lm_head_bias``).
    Logits are computed through ``self.apply_lm_head``, which is inherited from
    ``EasyDeLBaseModule``. Any embedding/head weight tying is presumably handled
    there — TODO confirm.

    Attributes:
        config (PhiMoeConfig): Configuration object for the model.
        model (PhiMoeModel): The underlying PhiMoE transformer.
        vocab_size (int): ``config.vocab_size``.
        lm_head (ColumnParallelLinear): Projection to vocabulary logits.
    """

    def __init__(
        self,
        config: PhiMoeConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Build the base transformer and the language-model head.

        Args:
            config (PhiMoeConfig): Model configuration.
            dtype (jnp.dtype): Computation dtype. Defaults to jnp.bfloat16.
            param_dtype (jnp.dtype): Parameter dtype. Defaults to jnp.bfloat16.
            precision (jax.lax.PrecisionLike): JAX matmul precision. Defaults to None.
            rngs (nn.Rngs): Random number generators used for parameter init.
        """
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.model = PhiMoeModel(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.vocab_size = self.config.vocab_size
        remat_head_cls = auto_remat(
            ColumnParallelLinear,
            policy=config.gradient_checkpointing,
            save_names=config.gradient_checkpointing_targets,
            exclude_names=config.gradient_checkpointing_targets,
        )
        self.lm_head = remat_head_cls(
            config.hidden_size,
            config.vocab_size,
            use_bias=config.lm_head_bias,
            kernel_init=jax.nn.initializers.normal(config.initializer_range),
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )

    def __call__(
        self,
        input_ids: Int[Array, "batch seq_len"] | None = None,
        inputs_embeds: Float[Array, "batch seq_len hidden_dim"] | None = None,
        attention_mask: Bool[Array, "batch seq_len"] | None = None,
        mask_info: MaskInfo | None = None,
        position_ids: Int[Array, "batch seq_len"] | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        mode: common_types.RUNTIME_MODE_TYPES | None = None,  # type:ignore
        past_key_values: TransformerCache | RaggedPagesCache | None = None,
        cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None,
        apply_lm_head: bool = True,
    ) -> CausalLMOutput:
        """Run the transformer and, optionally, the language-model head.

        Args:
            input_ids: Input token IDs (batch, seq_len). Exactly one of
                ``input_ids`` / ``inputs_embeds`` must be given.
            inputs_embeds: Precomputed input embeddings (batch, seq_len, hidden_dim).
            attention_mask: Padding mask (batch, seq_len).
            mask_info (MaskInfo | None): Precomputed mask information.
            position_ids: Position indices (batch, seq_len).
            output_attentions (bool | None): Whether to collect attention weights.
            output_hidden_states (bool | None): Whether to collect per-layer hidden states.
            mode: Runtime mode forwarded to the base model.
            past_key_values: KV cache forwarded to the base model.
            cache_metadata: Metadata for the cache (e.g. paged attention).
            apply_lm_head (bool): When False, ``logits`` is None and only hidden
                states are returned. Defaults to True.

        Returns:
            CausalLMOutput: ``logits`` (or None), ``hidden_states``,
            ``last_hidden_state``, ``attentions`` and ``past_key_values``.
        """
        base_outputs = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            mask_info=mask_info,
            position_ids=position_ids,
            mode=mode,
            past_key_values=past_key_values,
            cache_metadata=cache_metadata,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        final_hidden = base_outputs.last_hidden_state
        logits = (
            checkpoint_name(self.apply_lm_head(final_hidden), "lm_head_output")
            if apply_lm_head
            else None
        )
        return CausalLMOutput(
            logits=logits,
            hidden_states=base_outputs.hidden_states,
            last_hidden_state=base_outputs.last_hidden_state,
            attentions=base_outputs.attentions,
            past_key_values=base_outputs.past_key_values,
        )

    def get_encoder(self):
        """Decoder-only model: there is no encoder, so this always raises."""
        raise NotImplementedError("This is a decoder-only model and does not have an encoder.")

    def get_decoder(self):
        """Return the decoder, delegating to the wrapped base model."""
        return self.model.get_decoder()

    def get_lm_head(self):
        """Return the language-model head (``self.lm_head``)."""
        return self.lm_head

    def get_embedding(self):
        """Return the token embedding layer of the wrapped base model."""
        return self.model.get_embedding()