# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from typing import ClassVar
import jax.lax
from chex import Array
from eformer import common_types
from eformer.escale import apply_logical_sharding
from ejkernel.types import MaskInfo
from flax import nnx as nn
from jax import numpy as jnp
from jax.ad_checkpoint import checkpoint_name
from jaxtyping import Bool, Float, Int
from easydel.infra.base_module import EasyDeLBaseModule
from easydel.infra.factory import TaskType, register_module
from easydel.infra.modeling_outputs import AttentionLayerOutput, BaseModelOutput, CausalLMOutput, DecoderLayerOutput
from easydel.infra.utils import ACT2FN, auto_remat, block_wise_ffn, get_dot_general_by_bits
from easydel.layers.attention_unified import UnifiedAttention
from easydel.layers.base_modules import BaseCausalLMModule
from easydel.layers.caching import (
RaggedPagesCache,
RaggedPagesCacheView,
RaggedPagesMetadata,
TransformerCache,
TransformerCacheView,
TransformerMetadata,
)
from easydel.layers.linear import ColumnParallelLinear, RowParallelLinear
from easydel.layers.norms import RMSNorm as RMSNorm
from .phi3_configuration import Phi3Config
class Phi3MLP(nn.Module):
    """Phi3 MLP module.

    Implements the feed-forward network (MLP) used in the Phi-3 model.

    It performs the following steps:
    - A single fused gate/up projection.
    - Splitting that output into a gate half and an up half.
    - Applying the configured activation to the gate half.
    - Multiplying the two halves element-wise.
    - A down projection back to ``hidden_size``.

    Attributes:
        config (Phi3Config): Configuration object for the model.
        dtype (jnp.dtype): Data type for computations.
        param_dtype (jnp.dtype): Data type for parameters.
        precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
        gate_up_proj (ColumnParallelLinear): Fused bias-free projection from ``hidden_size``
            to ``2 * intermediate_size`` (gate and up halves).
        down_proj (RowParallelLinear): Bias-free projection from ``intermediate_size``
            back to ``hidden_size``.
        activation_fn (callable): Activation looked up as ``ACT2FN[config.hidden_act]``.
    """

    def __init__(
        self,
        config: Phi3Config,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
        layer_idx: int | None = None,
    ):
        """Initializes the Phi3MLP module.

        Args:
            config (Phi3Config): The configuration object for the Phi-3 model.
            dtype (jnp.dtype): Data type for computation. Defaults to jnp.bfloat16.
            param_dtype (jnp.dtype): Data type for parameters. Defaults to jnp.bfloat16.
            precision (jax.lax.PrecisionLike): Precision setting for JAX operations. Defaults to None.
            rngs (nn.Rngs): Random number generators.
            layer_idx (int | None): Index of the owning decoder layer. It is accepted for
                signature parity with other per-layer modules and is not used here.
                It is optional so callers that do not pass it keep working.
        """
        self.config = config
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        # Both projections share dtype/init/precision/quantization settings; only the
        # parallelism style (column vs row) differs.
        column_parallel_linear = functools.partial(
            ColumnParallelLinear,
            dtype=dtype,
            param_dtype=param_dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(config.initializer_range),
            precision=precision,
            rngs=rngs,
            **get_dot_general_by_bits(config.bits, config.easy_method),
        )
        row_parallel_linear = functools.partial(
            RowParallelLinear,
            dtype=dtype,
            param_dtype=param_dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(config.initializer_range),
            precision=precision,
            rngs=rngs,
            **get_dot_general_by_bits(config.bits, config.easy_method),
        )
        # Gate and up projections are fused into one matmul of width 2 * intermediate_size.
        self.gate_up_proj = column_parallel_linear(
            config.hidden_size,
            2 * config.intermediate_size,
        )
        self.down_proj = row_parallel_linear(
            config.intermediate_size,
            config.hidden_size,
        )
        self.activation_fn = ACT2FN[self.config.hidden_act]

    def __call__(
        self, hidden_states: Float[Array, "batch seq_len hidden_dim"]
    ) -> Float[Array, "batch seq_len hidden_dim"]:
        """Forward pass of the Phi3MLP module.

        Args:
            hidden_states: Input hidden states.

        Returns:
            Output hidden states after the MLP transformation. They are sharded with
            ``HiddenStateSharding`` on both input and output.
        """
        hidden_states = apply_logical_sharding(
            hidden_states,
            dynamic_axes=common_types.HiddenStateSharding,
            partition_manager=self.config.partition_manager,
        )
        up_states = self.gate_up_proj(hidden_states)
        # The first half of the last axis is the gate; the second half is the up projection.
        gate, up_states = jnp.split(up_states, 2, axis=-1)
        gate = checkpoint_name(self.activation_fn(gate), "mlp_gate")
        up_states = checkpoint_name(up_states * gate, "mlp_up")
        hidden_states = checkpoint_name(self.down_proj(up_states), "mlp_down")
        hidden_states = apply_logical_sharding(
            hidden_states,
            dynamic_axes=common_types.HiddenStateSharding,
            partition_manager=self.config.partition_manager,
        )
        return checkpoint_name(hidden_states, "mlp_output")
class Phi3Attention(UnifiedAttention):
    """Phi3 attention module with a fused QKV projection.

    Implements the multi-head self-attention used in the Phi-3 model. The module is
    configured as causal attention. It supports:
    - grouped-query attention (``num_key_value_heads`` may differ from
      ``num_attention_heads``);
    - rotary position embeddings;
    - an optional sliding window taken from ``config.sliding_window``.

    Query, key and value are produced by a single fused linear layer.

    Attributes:
        config (Phi3Config): Configuration object for the model.
        dtype (jnp.dtype): Data type for computations.
        param_dtype (jnp.dtype): Data type for parameters.
        precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
        sliding_window: Sliding window size forwarded to the parent and used in the
            forward pass (presumably stored by ``UnifiedAttention`` — confirm).
        qkv_proj (ColumnParallelLinear): Fused bias-free query/key/value projection.
        o_proj: Output projection built by the inherited ``_create_o_proj``.
        attention_performer: Core attention module built by the inherited
            ``_create_attention_performer``.
        rotary: Rotary embedding module built by ``_create_rotary``.
        resid_dropout (nn.Dropout | None): Dropout on the projected output, created only
            when ``config.resid_pdrop > 0``.
    """

    # Maps the generic projection names used by the parent to this module's attribute names.
    projection_mapping: ClassVar[dict[str, str]] = {
        "output_projection": "o_proj",
        "query_key_value_projection": "qkv_proj",
    }

    def __init__(
        self,
        config: Phi3Config,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
        layer_idx: int,
    ):
        """Initializes the Phi3Attention module.

        Args:
            config (Phi3Config): The configuration object for the Phi-3 model.
            dtype (jnp.dtype): Data type for computation. Defaults to jnp.bfloat16.
            param_dtype (jnp.dtype): Data type for parameters. Defaults to jnp.bfloat16.
            precision (jax.lax.PrecisionLike): Precision setting for JAX operations. Defaults to None.
            rngs (nn.Rngs): Random number generators.
            layer_idx (int): Index of this layer within the decoder stack; it is
                forwarded to ``UnifiedAttention``.

        Note:
            Any validation of head counts / head dimension happens in
            ``UnifiedAttention``, which is not visible here.
        """
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
            layer_idx=layer_idx,
            attention_type="standard",
            causal=True,
            sliding_window=config.sliding_window,
        )

    def define_network(
        self,
        config: Phi3Config,
        dtype: jnp.dtype,
        param_dtype: jnp.dtype,
        precision: jax.lax.PrecisionLike,
        rngs: nn.Rngs,
    ):
        """Override to create a fused QKV projection instead of separate Q/K/V layers.

        Args:
            config: Model configuration.
            dtype: Data type for computations.
            param_dtype: Data type for parameters.
            precision: JAX precision setting.
            rngs: Random number generators.
        """
        # Output width = query heads + key heads + value heads, each of width head_dim.
        qkv_size = config.num_attention_heads * self.head_dim + 2 * config.num_key_value_heads * self.head_dim
        self.qkv_proj = ColumnParallelLinear(
            config.hidden_size,
            qkv_size,
            rngs=rngs,
            use_bias=False,
            dtype=dtype,
            param_dtype=param_dtype,
            kernel_init=jax.nn.initializers.normal(config.initializer_range),
            precision=precision,
            **get_dot_general_by_bits(config.bits, config.easy_method),
        )
        self.o_proj = self._create_o_proj(config, dtype, param_dtype, precision, rngs)
        self.attention_performer = self._create_attention_performer(config, rngs)
        self.rotary = self._create_rotary(config, dtype)
        # NOTE(review): Phi3DecoderLayer also applies its own resid_attn_dropout to this
        # module's output. When resid_pdrop > 0, dropout therefore appears to be applied
        # twice on the attention path — confirm this is intended.
        if hasattr(config, "resid_pdrop") and config.resid_pdrop > 0:
            self.resid_dropout = nn.Dropout(rate=config.resid_pdrop, rngs=rngs)
        else:
            self.resid_dropout = None

    def _create_rotary(self, config: Phi3Config, dtype: jnp.dtype):
        """Create the rotary position embedding layer for Phi-3.

        The layer uses ``head_size=self.head_dim``, ``base=config.rope_theta`` and
        NeoX-style rotation. Any partial-rotary or rope-scaling behaviour is decided
        inside ``config.get_basic_rope`` (not visible here).

        Args:
            config: Model configuration.
            dtype: Data type for computations.
        """
        return config.get_basic_rope(
            dtype=dtype,
            head_size=self.head_dim,
            base=config.rope_theta,
            is_neox_style=True,
        )

    def __call__(
        self,
        hidden_states: Float[Array, "batch seq_len hidden_dim"],
        mask_info: MaskInfo | None,
        position_ids: Int[Array, "batch seq_len"],
        mode: common_types.RUNTIME_MODE_TYPES,  # type:ignore
        cache_view: TransformerCacheView | RaggedPagesCacheView | None = None,
        cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None,
        output_attentions: bool = False,
        frequencies: Float[Array, "seq_len head_dim"] | None = None,
    ) -> AttentionLayerOutput:
        """Forward pass of the Phi3Attention module.

        Overrides the parent forward so that it can split the fused QKV output. It then
        reuses the inherited helpers for the remaining steps: sharding, rotary
        application, cache concatenation, the attention kernel, head merging and the
        output projection. Causal attention with ``self.sliding_window`` is requested
        from the attention kernel.

        Args:
            hidden_states: Input hidden states.
            mask_info: Mask information for attention.
            position_ids: Position indices for the tokens.
            mode: Runtime mode (train/eval/infer).
            cache_view: Cache view for attention KVs.
            cache_metadata: Metadata for paged attention.
            output_attentions: Whether to return attention weights.
            frequencies: Precomputed rotary frequency embeddings.

        Returns:
            AttentionLayerOutput containing the attention output, the attention weights
            (only when ``output_attentions`` is true) and the updated cache view.
        """
        batch_size, sequence_length = hidden_states.shape[:2]
        qkv = checkpoint_name(self.qkv_proj(hidden_states), "attn_qkv")

        # Slice the fused projection into [query | key | value] along the last axis.
        q_size = self.config.num_attention_heads * self.head_dim
        kv_size = self.config.num_key_value_heads * self.head_dim
        query_states = qkv[..., :q_size]
        key_states = qkv[..., q_size : q_size + kv_size]
        value_states = qkv[..., q_size + kv_size :]

        query_states = query_states.reshape(
            batch_size,
            sequence_length,
            self.config.num_attention_heads,
            self.head_dim,
        )
        key_states = key_states.reshape(
            batch_size,
            sequence_length,
            self.config.num_key_value_heads,
            self.head_dim,
        )
        value_states = value_states.reshape(
            batch_size,
            sequence_length,
            self.config.num_key_value_heads,
            self.head_dim,
        )
        query_states, key_states, value_states = self.apply_qkv_shardings(query_states, key_states, value_states)
        query_states, key_states = self._apply_rotary(query_states, key_states, position_ids, frequencies)
        (
            key_states,
            value_states,
            mask_info,
            init_attention_bias,
            cache_view,
            cache_metadata,
        ) = self.concatenate(
            query=query_states,
            key=key_states,
            value=value_states,
            cache_view=cache_view,
            cache_metadata=cache_metadata,
            mask_info=mask_info,
            sliding_window=self.sliding_window,
        )
        attentions = self.attention_performer.forward(
            query_states=query_states,
            key_states=key_states,
            value_states=value_states,
            mode=mode,
            bias=None,
            cache_metadata=cache_metadata,
            cache_view=cache_view,
            init_bias=init_attention_bias,
            mask_info=mask_info,
            causal=True,
            sliding_window=self.sliding_window,
        )
        attn_output = self._merge_heads(attentions.attention_outputs)
        attn_output = self.shard_attention_prod(attn_output=attn_output)
        attn_output = checkpoint_name(self.output_projection(attn_output), "attn_output")
        if self.resid_dropout is not None:
            attn_output = self.resid_dropout(attn_output)
        return AttentionLayerOutput(
            attention_output=attn_output,
            attention_weight=attentions.attention_weights if output_attentions else None,
            cache_view=cache_view,
        )
class Phi3DecoderLayer(nn.Module):
    """Phi3 Transformer decoder layer.

    A single pre-norm decoder layer of the Phi-3 model:
    ``x + dropout(attn(norm(x)))`` followed by ``x + dropout(mlp(norm(x)))``.
    RMS normalization is applied before each sub-layer.

    Attributes:
        config (Phi3Config): Configuration object for the model.
        dtype (jnp.dtype): Data type for computations.
        param_dtype (jnp.dtype): Data type for parameters.
        precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
        input_layernorm (RMSNorm): RMS normalization applied before the attention layer.
        self_attn (Phi3Attention): The self-attention module (possibly remat-wrapped).
        mlp (Phi3MLP): The feed-forward module (possibly remat-wrapped).
        post_attention_layernorm (RMSNorm): RMS normalization applied after the attention
            residual and before the MLP.
        resid_attn_dropout (nn.Dropout): Dropout on the attention output before the residual add.
        resid_mlp_dropout (nn.Dropout): Dropout on the MLP output before the residual add.
    """

    def __init__(
        self,
        config: Phi3Config,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
        layer_idx: int,
    ):
        """Initializes the Phi3DecoderLayer.

        Args:
            config (Phi3Config): The configuration object for the Phi-3 model.
            dtype (jnp.dtype): Data type for computation. Defaults to jnp.bfloat16.
            param_dtype (jnp.dtype): Data type for parameters. Defaults to jnp.bfloat16.
            precision (jax.lax.PrecisionLike): Precision setting for JAX operations. Defaults to None.
            rngs (nn.Rngs): Random number generators.
            layer_idx (int): Index of this layer; forwarded to the attention and MLP blocks.
        """
        self.config = config
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        attn_block = Phi3Attention
        mlp_block = Phi3MLP
        # NOTE(review): the same target list is passed as both save_names and
        # exclude_names; confirm against auto_remat's semantics.
        attn_block, mlp_block = auto_remat(
            attn_block,
            mlp_block,
            policy=config.gradient_checkpointing,
            save_names=config.gradient_checkpointing_targets,
            exclude_names=config.gradient_checkpointing_targets,
        )
        self.self_attn = attn_block(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
            layer_idx=layer_idx,
        )
        # Phi3MLP declares a keyword-only `layer_idx`; pass it so construction does not
        # fail with a missing-argument TypeError.
        self.mlp = mlp_block(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
            layer_idx=layer_idx,
        )
        self.input_layernorm = RMSNorm(
            dim=config.hidden_size,
            eps=config.rms_norm_eps,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.resid_attn_dropout = nn.Dropout(
            self.config.resid_pdrop,
            rngs=rngs,
        )
        self.resid_mlp_dropout = nn.Dropout(
            self.config.resid_pdrop,
            rngs=rngs,
        )
        self.post_attention_layernorm = RMSNorm(
            dim=self.config.hidden_size,
            eps=self.config.rms_norm_eps,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )

    def __call__(
        self,
        hidden_states: Float[Array, "batch seq_len hidden_dim"],
        mask_info: MaskInfo | None,
        position_ids: Int[Array, "batch seq_len"],
        mode: common_types.RUNTIME_MODE_TYPES,  # type:ignore
        cache_view: TransformerCacheView | RaggedPagesCacheView | None = None,
        cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None,
        output_attentions: bool = False,
        frequencies: Float[Array, "seq_len head_dim"] | None = None,
    ) -> DecoderLayerOutput:
        """Forward pass of the Phi3DecoderLayer module.

        Args:
            hidden_states: Input hidden states.
            mask_info: Mask information forwarded to the attention block.
            position_ids: Position indices for the tokens. Shape: (batch_size, sequence_length).
            mode: Runtime mode forwarded to the attention block.
            cache_view: Cache view for attention KVs.
            cache_metadata: Metadata for paged attention.
            output_attentions: Whether to return attention weights. Default is False.
            frequencies: Precomputed rotary frequency embeddings.

        Returns:
            DecoderLayerOutput with the output hidden states, the attention weights
            (as returned by the attention block) and the updated cache view.
        """
        # --- Attention sub-layer (pre-norm + residual) ---
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = apply_logical_sharding(
            hidden_states,
            dynamic_axes=common_types.HiddenStateSharding,
            partition_manager=self.config.partition_manager,
        )
        attn_outputs = self.self_attn(
            hidden_states,
            mask_info,
            position_ids,
            mode,
            cache_view,
            cache_metadata,
            output_attentions,
            frequencies,
        )
        hidden_states = checkpoint_name(self.resid_attn_dropout(attn_outputs.attention_output) + residual, "residual")

        # --- MLP sub-layer (pre-norm + residual) ---
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        if self.config.use_scan_mlp:
            # Chunked MLP evaluation controlled by scan_mlp_chunk_size.
            feed_forward_hidden_states = block_wise_ffn(
                self.mlp,
                hidden_states,
                self.config.scan_mlp_chunk_size,
            )
        else:
            feed_forward_hidden_states = self.mlp(hidden_states)
        hidden_states = checkpoint_name(residual + self.resid_mlp_dropout(feed_forward_hidden_states), "residual")
        hidden_states = checkpoint_name(hidden_states, "layer_output")
        hidden_states = apply_logical_sharding(
            hidden_states,
            dynamic_axes=common_types.HiddenStateSharding,
            partition_manager=self.config.partition_manager,
        )
        return DecoderLayerOutput(
            hidden_states=hidden_states,
            attention_weight=attn_outputs.attention_weight,
            cache_view=attn_outputs.cache_view,
        )
@register_module(TaskType.BASE_MODULE, config=Phi3Config, model_type="phi3")
class Phi3Model(EasyDeLBaseModule):
    """The base Phi-3 transformer.

    It consists of:
    - a token embedding;
    - ``config.num_hidden_layers`` ``Phi3DecoderLayer`` layers;
    - a final RMS normalization.

    Attributes:
        config (Phi3Config): Configuration object for the model.
        dtype (jnp.dtype): Data type for computation.
        param_dtype (jnp.dtype): Data type for parameters.
        precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
        embed_tokens (nn.Embed): Embedding layer for input tokens (possibly remat-wrapped).
        embed_dropout (nn.Dropout): Dropout module built from ``config.embd_pdrop``.
            It is not applied in ``__call__``.
        layers (list[Phi3DecoderLayer]): Decoder layers.
        norm (RMSNorm): Final layer normalization.
    """

    def __init__(
        self,
        config: Phi3Config,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Initializes the Phi3Model.

        Args:
            config (Phi3Config): The configuration object for the Phi-3 model.
            dtype (jnp.dtype): Data type for computation. Defaults to jnp.bfloat16.
            param_dtype (jnp.dtype): Data type for parameters. Defaults to jnp.bfloat16.
            precision (jax.lax.PrecisionLike): Precision setting for JAX operations. Defaults to None.
            rngs (nn.Rngs): Random number generators.
        """
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        embed_block = auto_remat(
            nn.Embed,
            policy=config.gradient_checkpointing,
            save_names=config.gradient_checkpointing_targets,
            exclude_names=config.gradient_checkpointing_targets,
        )
        self.embed_tokens = embed_block(
            config.vocab_size,
            config.hidden_size,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.embed_dropout = nn.Dropout(config.embd_pdrop)
        self.layers = [
            Phi3DecoderLayer(
                config=config,
                layer_idx=idx,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )
            for idx in range(self.config.num_hidden_layers)
        ]
        # Pass rngs like every other RMSNorm constructed in this file.
        self.norm = RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )

    @functools.cached_property
    def frequencies(self):
        """Rotary frequencies shared by all layers; computed once per instance."""
        # NOTE(review): head_size here is hidden_size // num_attention_heads, while the
        # attention layers use self.head_dim. Confirm the two agree for configs that
        # set head_dim explicitly.
        return self.config.get_basic_frequencies(
            head_size=self.config.hidden_size // self.config.num_attention_heads,
            base=self.config.rope_theta,
        )

    def __call__(
        self,
        input_ids: Int[Array, "batch seq_len"] | None = None,
        inputs_embeds: Float[Array, "batch seq_len hidden_dim"] | None = None,
        attention_mask: Bool[Array, "batch seq_len"] | None = None,
        mask_info: MaskInfo | None = None,
        position_ids: Int[Array, "batch seq_len"] | None = None,
        mode: common_types.RUNTIME_MODE_TYPES | None = None,  # type:ignore
        past_key_values: TransformerCache | RaggedPagesCache | None = None,
        cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
    ) -> BaseModelOutput:
        """Forward pass of the Phi3Model.

        Args:
            input_ids: Input token IDs. Shape: (batch_size, sequence_length).
            inputs_embeds: Input embeddings. Exactly one of `input_ids` or
                `inputs_embeds` must be provided.
            attention_mask: Mask to avoid performing attention on padding token indices.
                Shape: (batch_size, sequence_length).
            mask_info: Precomputed mask information. If None, it is built by
                ``MaskInfo.dynamic_init`` from the other inputs.
            position_ids: Position indices for the tokens. If None, they are taken
                from ``mask_info.q_position_ids``.
            mode: Runtime mode. If None, it is MODE_DECODE when the sequence length is 1
                and a cache is given, otherwise MODE_TRAIN.
            past_key_values: Precomputed key/value states. If None, an empty
                ``TransformerCache`` with one view per layer is created.
            cache_metadata: Metadata for paged attention.
            output_attentions: Whether to collect attention weights. None/False disables it.
            output_hidden_states: Whether to collect hidden states (the input to every
                layer plus the final normed output). None/False disables it.

        Returns:
            BaseModelOutput containing `last_hidden_state`, `hidden_states` (optional),
            `attentions` (optional) and `past_key_values`.

        Raises:
            ValueError: If neither or both of `input_ids` and `inputs_embeds` are provided.
            AssertionError: If the sequence length exceeds `config.max_position_embeddings`.
        """
        # XOR is true exactly when both are given or both are missing.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )
        if inputs_embeds is None:
            inputs_embeds = checkpoint_name(self.embed_tokens(input_ids.astype("i4")), "embeddings")
        sequence_length = inputs_embeds.shape[1]
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        assert sequence_length <= self.config.max_position_embeddings, (
            f"Maximum Position Embedding Reached ! "
            f"(Excepted <= {self.config.max_position_embeddings} got {sequence_length})"
        )
        mask_info = MaskInfo.dynamic_init(
            mask_info=mask_info,
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )
        if position_ids is None:
            position_ids = mask_info.q_position_ids
        if mode is None:
            mode = (
                common_types.MODE_DECODE
                if sequence_length == 1 and past_key_values is not None
                else common_types.MODE_TRAIN
            )
        if past_key_values is None:
            past_key_values = TransformerCache.init_empty(len(self.layers))
        hidden_states = apply_logical_sharding(
            inputs_embeds,
            dynamic_axes=common_types.HiddenStateSharding,
            partition_manager=self.config.partition_manager,
        )
        for idx, block in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            layer_outputs = block(
                hidden_states=hidden_states,
                mask_info=mask_info,
                position_ids=position_ids,
                mode=mode,
                cache_view=past_key_values.views[idx],
                cache_metadata=cache_metadata,
                output_attentions=output_attentions,
                frequencies=self.frequencies,
            )
            hidden_states = layer_outputs.hidden_states
            if output_attentions:
                all_attentions += (layer_outputs.attention_weight,)
            # Write the layer's updated cache view back into the shared cache.
            past_key_values[idx] = layer_outputs.cache_view
        hidden_states = self.norm(hidden_states)
        hidden_states = checkpoint_name(hidden_states, "model_output")
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            past_key_values=past_key_values,
        )

    def get_encoder(self):
        """Not supported: Phi-3 is decoder-only.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("This is a decoder-only model and does not have an encoder.")

    def get_decoder(self):
        """Return the decoder, which is this module itself."""
        return self

    def get_lm_head(self):
        """Not supported: the base model has no language-model head.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("The base model does not have a language model head.")

    def get_embedding(self):
        """Return the token embedding layer (`embed_tokens`)."""
        return self.embed_tokens
@register_module(TaskType.CAUSAL_LM, config=Phi3Config, model_type="phi3")
class Phi3ForCausalLM(BaseCausalLMModule[Phi3Model, Phi3Config]):
    """Phi-3 model with a Causal Language Modeling head.

    The backbone is a ``Phi3Model`` registered under the attribute name ``model``.
    The LM head is set up by ``BaseCausalLMModule`` with ``lm_head_bias=False``.
    """

    _task_type = TaskType.CAUSAL_LM
    _model_type = "phi3"
    _config_class = Phi3Config

    def __init__(
        self,
        config: Phi3Config,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Build the backbone and LM head through ``BaseCausalLMModule``.

        Args:
            config: The Phi-3 configuration.
            dtype: Computation dtype. Defaults to jnp.bfloat16.
            param_dtype: Parameter dtype. Defaults to jnp.bfloat16.
            precision: JAX precision setting. Defaults to None.
            rngs: Random number generators.
        """
        super().__init__(
            config=config,
            base_model_class=Phi3Model,
            base_model_name="model",
            lm_head_bias=False,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )

    def __call__(
        self,
        input_ids: Int[Array, "batch seq_len"] | None = None,
        inputs_embeds: Float[Array, "batch seq_len hidden_dim"] | None = None,
        attention_mask: Bool[Array, "batch seq_len"] | None = None,
        mask_info: MaskInfo | None = None,
        position_ids: Int[Array, "batch seq_len"] | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        mode: common_types.RUNTIME_MODE_TYPES | None = None,  # type:ignore
        past_key_values: TransformerCache | RaggedPagesCache | None = None,
        cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None,
        apply_lm_head: bool = True,
    ) -> CausalLMOutput:
        """Run the Phi-3 backbone and, optionally, project onto the vocabulary.

        Every argument except ``apply_lm_head`` is forwarded unchanged to ``self.model``.

        Args:
            input_ids: Input token IDs. Shape: (batch_size, sequence_length).
            inputs_embeds: Input embeddings; give exactly one of this or `input_ids`.
            attention_mask: Padding mask. Shape: (batch_size, sequence_length).
            mask_info: Precomputed mask information.
            position_ids: Position indices for the tokens.
            output_attentions: Whether the backbone should collect attention weights.
            output_hidden_states: Whether the backbone should collect per-layer hidden states.
            mode: Runtime mode.
            past_key_values: Precomputed key/value cache.
            cache_metadata: Metadata for paged attention.
            apply_lm_head: If True, compute logits with the LM head; if False, `logits` is None.

        Returns:
            CausalLMOutput with `logits`, `hidden_states`, `last_hidden_state`,
            `attentions` and `past_key_values`.
        """
        base_outputs = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            mask_info=mask_info,
            position_ids=position_ids,
            mode=mode,
            past_key_values=past_key_values,
            cache_metadata=cache_metadata,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        final_hidden = base_outputs.last_hidden_state
        if apply_lm_head:
            logits = checkpoint_name(self.apply_lm_head(final_hidden), "lm_head_output")
        else:
            logits = None
        return CausalLMOutput(
            logits=logits,
            hidden_states=base_outputs.hidden_states,
            last_hidden_state=final_hidden,
            attentions=base_outputs.attentions,
            past_key_values=base_outputs.past_key_values,
        )

    def get_encoder(self):
        """Not supported: Phi-3 has no encoder.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("This is a decoder-only model and does not have an encoder.")

    def get_decoder(self):
        """Delegate to the backbone's ``get_decoder``."""
        backbone = self.model
        return backbone.get_decoder()

    def get_lm_head(self):
        """Return the language-model head module (`lm_head`)."""
        head = self.lm_head
        return head

    def get_embedding(self):
        """Delegate to the backbone's ``get_embedding`` (token embedding table)."""
        backbone = self.model
        return backbone.get_embedding()