# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import typing as tp
import chex
import jax.lax
from chex import Array
from eformer import common_types
from eformer.escale import apply_logical_sharding
from flax import nnx as nn
from jax import numpy as jnp
from easydel.infra.base_module import EasyDeLBaseModule
from easydel.infra.factory import TaskType, register_module
from easydel.infra.modeling_outputs import (
AttentionLayerOutput,
BaseModelOutput,
CausalLMOutput,
DecoderLayerOutput,
)
from easydel.infra.utils import (
ACT2FN,
auto_remat,
block_wise_ffn,
get_dot_general_by_bits,
)
from easydel.layers.attention import AttentionModule, FlexibleAttentionModule
from easydel.layers.caching import (
PagedAttentionCache,
PagedAttentionCacheView,
PagedAttentionMetadata,
TransformerCache,
TransformerCacheView,
TransformerMetadata,
)
from easydel.layers.linear import ParallelLinear
from easydel.layers.norms import RMSNorm as RMSNorm
from .phi3_configuration import Phi3Config
[docs]class Phi3MLP(nn.Module):
"""Phi3 MLP module.
This module implements the feed-forward network (MLP) used in the Phi-3 model.
It consists of a combined gate and up projection, SiLU activation, and a down projection.
Attributes:
config (Phi3Config): Configuration object for the model.
dtype (jnp.dtype): Data type for computations.
param_dtype (jnp.dtype): Data type for parameters.
precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
gate_up_proj (ParallelLinear): Combined linear layer for gate and up projections.
down_proj (ParallelLinear): Linear layer for the down projection.
activation_fn (callable): Activation function (SiLU).
"""
def __init__(
self,
config: Phi3Config,
dtype: jnp.dtype = jnp.float32,
param_dtype: jnp.dtype = jnp.float32,
precision: jax.lax.PrecisionLike = None,
*,
rngs: nn.Rngs,
):
"""Initializes the Phi3MLP module.
Args:
config (Phi3Config): The configuration object for the Phi-3 model.
dtype (jnp.dtype): Data type for computation. Defaults to jnp.float32.
param_dtype (jnp.dtype): Data type for parameters. Defaults to jnp.float32.
precision (jax.lax.PrecisionLike): Precision setting for JAX operations. Defaults to None.
rngs (nn.Rngs): Random number generators.
"""
self.config = config
self.dtype = dtype
self.param_dtype = param_dtype
self.precision = precision
linear_class = functools.partial(
ParallelLinear,
dtype=dtype,
param_dtype=param_dtype,
use_bias=False,
kernel_init=jax.nn.initializers.normal(config.initializer_range),
precision=precision,
rngs=rngs,
**get_dot_general_by_bits(config.bits, config.easy_method),
)
self.gate_up_proj = linear_class(
config.hidden_size,
2 * config.intermediate_size,
rngs=rngs,
)
self.down_proj = linear_class(
config.intermediate_size,
config.hidden_size,
rngs=rngs,
)
self.activation_fn = ACT2FN[self.config.hidden_act]
def __call__(self, hidden_states: Array) -> Array:
"""Forward pass of the Phi3MLP module.
Args:
hidden_states (Array): Input hidden states. Shape: (batch_size, sequence_length, hidden_size).
Returns:
Array: Output hidden states after MLP transformation. Shape: (batch_size, sequence_length, hidden_size).
"""
hidden_states = apply_logical_sharding(
hidden_states,
dynamic_axes=common_types.HiddenStateSharding,
partition_manager=self.config.partition_manager,
)
up_states = self.gate_up_proj(hidden_states)
gate, up_states = jnp.split(up_states, 2, axis=-1)
up_states = up_states * self.activation_fn(gate)
hidden_states = self.down_proj(up_states)
hidden_states = apply_logical_sharding(
hidden_states,
dynamic_axes=common_types.HiddenStateSharding,
partition_manager=self.config.partition_manager,
)
return hidden_states
[docs]class Phi3Attention(AttentionModule):
"""Phi3 Attention module.
This module implements the multi-head attention mechanism used in the Phi-3 model.
It supports Grouped Query Attention (GQA) and Rotary Position Embeddings (RoPE).
The query, key, and value projections are combined into a single linear layer.
Attributes:
config (Phi3Config): Configuration object for the model.
dtype (jnp.dtype): Data type for computations.
param_dtype (jnp.dtype): Data type for parameters.
precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
rngs (nn.Rngs): Random number generators.
attention_dropout (float): Dropout probability for attention scores.
hidden_size (int): Dimensionality of the hidden states.
num_heads (int): Number of attention query heads.
head_dim (int): Dimensionality of each attention head.
num_key_value_heads (int): Number of attention key/value heads (for GQA).
num_key_value_groups (int): Number of query head groups for each key/value head.
max_position_embeddings (int): Maximum sequence length supported by RoPE.
original_max_position_embeddings (int): Original max sequence length for RoPE scaling.
is_causal (bool): Whether the attention is causal (always True for this implementation).
o_proj (ParallelLinear): Linear layer for the output projection.
qkv_proj (ParallelLinear): Combined linear layer for query, key, and value projections.
attention_performer (FlexibleAttentionModule): Module to perform the core attention computation.
rotary (RoPE): Rotary position embedding module.
"""
def __init__(
self,
config: Phi3Config,
dtype: jnp.dtype = jnp.float32,
param_dtype: jnp.dtype = jnp.float32,
precision: jax.lax.PrecisionLike = None,
*,
rngs: nn.Rngs,
):
"""Initializes the Phi3Attention module.
Args:
config (Phi3Config): The configuration object for the Phi-3 model.
dtype (jnp.dtype): Data type for computation. Defaults to jnp.float32.
param_dtype (jnp.dtype): Data type for parameters. Defaults to jnp.float32.
precision (jax.lax.PrecisionLike): Precision setting for JAX operations. Defaults to None.
rngs (nn.Rngs): Random number generators.
Raises:
ValueError: If `hidden_size` is not divisible by `num_heads`.
"""
super().__init__(config=config)
self.dtype = dtype
self.param_dtype = param_dtype
self.precision = precision
self.rngs = rngs
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.original_max_position_embeddings = config.original_max_position_embeddings
self.is_causal = True
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
linear_class = functools.partial(
ParallelLinear,
use_bias=False,
precision=precision,
dtype=dtype,
param_dtype=param_dtype,
kernel_init=jax.nn.initializers.normal(config.initializer_range),
**get_dot_general_by_bits(config.bits, config.easy_method),
)
op_size = self.num_heads * self.head_dim + 2 * (
self.num_key_value_heads * self.head_dim
)
self.o_proj = linear_class(
self.num_heads * self.head_dim,
self.hidden_size,
rngs=rngs,
)
self.qkv_proj = linear_class(
self.hidden_size,
op_size,
rngs=rngs,
)
self.attention_performer = FlexibleAttentionModule(
base_config=config,
softmax_scale=self.head_dim**-0.5,
dropout_prob=config.attention_dropout,
)
self.rotary = self.config.get_basic_rope(
self.dtype,
head_size=self.head_dim,
base=config.rope_theta,
is_neox_style=True,
)
def __call__(
self,
hidden_states: chex.Array,
attention_mask: chex.Array,
position_ids: chex.Array,
causal_mask: tp.Optional[chex.Array | bool],
mode: common_types.RUNTIME_MODE_TYPES, # type:ignore
cache_view: tp.Optional[TransformerCacheView | PagedAttentionCacheView] = None,
cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None,
segment_ids: tp.Optional[chex.Array] = None,
output_attentions: bool = False,
fcm_mask: tp.Optional[chex.Array] = None,
frequencies: tp.Optional[chex.Array] = None,
):
"""
Forward pass of the Phi3Attention module.
Args:
hidden_states (chex.Array): Input hidden states. Shape: (batch_size, sequence_length, hidden_size).
attention_mask (chex.Array): Mask to apply on the attention scores. Shape: (batch_size, 1, query_length, key_length).
position_ids (chex.Array): Position indices for the tokens. Shape: (batch_size, sequence_length).
causal_mask (tp.Optional[chex.Array | bool]): Causal mask for ensuring autoregressive behavior.
cache_view (tp.Optional[TransformerCacheView | PagedAttentionCacheView]): Cache view for attention KVs.
cache_metadata (tp.Optional[TransformerMetadata | PagedAttentionMetadata]): Metadata for paged attention.
segment_ids (tp.Optional[chex.Array]): Segment IDs for segment-based attention (optional).
output_attentions (bool): Whether to return attention weights. Default is False.
fcm_mask (tp.Optional[chex.Array]): Flash Chunking Mask (FCM) for attention.
frequencies (tp.Optional[chex.Array]): Precomputed rotary frequency embeddings.
Returns:
tp.Union[tp.Tuple[chex.Array, chex.Array], tp.Tuple[chex.Array]]:
A tuple containing the attention output hidden states. If `output_attentions` is True,
it also includes the attention weights.
"""
batch_size, sequence_length = hidden_states.shape[:2]
qkv = self.qkv_proj(hidden_states)
query_pos = self.num_heads * self.head_dim
query_states = qkv[..., :query_pos]
key_states = qkv[
..., query_pos : query_pos + self.num_key_value_heads * self.head_dim
]
value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
query_states = query_states.reshape(
batch_size,
sequence_length,
self.config.num_attention_heads,
self.head_dim,
)
key_states = key_states.reshape(
batch_size,
sequence_length,
self.config.num_key_value_heads,
self.head_dim,
)
value_states = value_states.reshape(
batch_size,
sequence_length,
self.config.num_key_value_heads,
self.head_dim,
)
(
query_states,
key_states,
value_states,
) = self.apply_qkv_shardings(query_states, key_states, value_states)
query_states, key_states = self.rotary(
positions=position_ids,
query=query_states,
key=key_states,
frequencies=frequencies,
)
(
key_states,
value_states,
attention_mask,
init_attention_bias,
cache_view,
) = self.concatenate(
query=query_states,
key=key_states,
value=value_states,
cache_view=cache_view,
attention_mask=attention_mask,
causal_mask=causal_mask,
fcm_mask=fcm_mask,
)
attentions = self.attention_performer.forward(
query_states=query_states,
key_states=key_states,
value_states=value_states,
mode=mode,
bias=None,
cache_metadata=cache_metadata,
cache_view=cache_view,
init_bias=init_attention_bias,
attention_mask=attention_mask,
segment_ids=segment_ids,
causal=True,
dropout_rng=self.rngs.params(),
)
attn_output = self.shard_attention_prod(
self._merge_heads(attentions.attention_outputs)
)
attn_output = self.o_proj(attn_output)
return AttentionLayerOutput(
attention_output=attn_output,
attention_weight=attentions.attention_weights if output_attentions else None,
cache_view=cache_view,
)
[docs]class Phi3DecoderLayer(nn.Module):
"""Phi3 Transformer Decoder Layer.
This module represents a single decoder layer in the Phi-3 model,
combining self-attention and MLP sub-layers with residual connections
and RMS normalization.
Attributes:
config (Phi3Config): Configuration object for the model.
dtype (jnp.dtype): Data type for computations.
param_dtype (jnp.dtype): Data type for parameters.
precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
rngs (nn.Rngs): Random number generators.
input_layernorm (RMSNorm): RMS normalization applied before the attention layer.
self_attn (Phi3Attention): The self-attention module.
mlp (Phi3MLP): The feed-forward (MLP) module.
post_attention_layernorm (RMSNorm): RMS normalization applied after the attention layer and before the MLP layer.
dropout (nn.Dropout): Dropout layer applied to the residual connections.
"""
def __init__(
self,
config: Phi3Config,
dtype: jnp.dtype = jnp.float32,
param_dtype: jnp.dtype = jnp.float32,
precision: jax.lax.PrecisionLike = None,
*,
rngs: nn.Rngs,
):
"""Initializes the Phi3DecoderLayer.
Args:
config (Phi3Config): The configuration object for the Phi-3 model.
dtype (jnp.dtype): Data type for computation. Defaults to jnp.float32.
param_dtype (jnp.dtype): Data type for parameters. Defaults to jnp.float32.
precision (jax.lax.PrecisionLike): Precision setting for JAX operations. Defaults to None.
rngs (nn.Rngs): Random number generators.
"""
self.config = config
self.dtype = dtype
self.param_dtype = param_dtype
self.precision = precision
attn_block = Phi3Attention
mlp_block = Phi3MLP
attn_block, mlp_block = auto_remat(
attn_block,
mlp_block,
policy=config.gradient_checkpointing,
)
self.self_attn = attn_block(
config=config,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
self.mlp = mlp_block(
config=config,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
self.input_layernorm = RMSNorm(
dim=config.hidden_size,
eps=config.rms_norm_eps,
dtype=dtype,
param_dtype=param_dtype,
rngs=rngs,
)
self.resid_attn_dropout = nn.Dropout(
self.config.resid_pdrop,
rngs=rngs,
)
self.resid_mlp_dropout = nn.Dropout(
self.config.resid_pdrop,
rngs=rngs,
)
self.post_attention_layernorm = RMSNorm(
dim=self.config.hidden_size,
eps=self.config.rms_norm_eps,
dtype=dtype,
param_dtype=param_dtype,
rngs=rngs,
)
def __call__(
self,
hidden_states: chex.Array,
attention_mask: chex.Array,
position_ids: chex.Array,
causal_mask: tp.Optional[chex.Array | bool],
mode: common_types.RUNTIME_MODE_TYPES, # type:ignore
cache_view: tp.Optional[TransformerCacheView | PagedAttentionCacheView] = None,
cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None,
segment_ids: tp.Optional[chex.Array] = None,
output_attentions: bool = False,
fcm_mask: tp.Optional[chex.Array] = None,
frequencies: tp.Optional[chex.Array] = None,
):
"""Forward pass of the Phi3DecoderLayer module.
Args:
hidden_states (chex.Array): Input hidden states. Shape: (batch_size, sequence_length, hidden_size).
attention_mask (chex.Array): Mask to apply on the attention scores. Shape: (batch_size, 1, query_length, key_length).
position_ids (chex.Array): Position indices for the tokens. Shape: (batch_size, sequence_length).
causal_mask (tp.Optional[chex.Array | bool]): Causal mask for ensuring autoregressive behavior.
cache_view (tp.Optional[TransformerCacheView | PagedAttentionCacheView]): Cache view for attention KVs.
cache_metadata (tp.Optional[TransformerMetadata | PagedAttentionMetadata]): Metadata for paged attention.
segment_ids (tp.Optional[chex.Array]): Segment IDs for segment-based attention (optional).
output_attentions (bool): Whether to return attention weights. Default is False.
fcm_mask (tp.Optional[chex.Array]): Flash Chunking Mask (FCM) for attention.
frequencies (tp.Optional[chex.Array]): Precomputed rotary frequency embeddings.
Returns:
tp.Tuple[chex.Array, tp.Optional[chex.Array]]:
A tuple containing the output hidden states and optionally the attention weights.
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = apply_logical_sharding(
hidden_states,
dynamic_axes=common_types.HiddenStateSharding,
partition_manager=self.config.partition_manager,
)
attn_outputs = self.self_attn(
hidden_states,
attention_mask,
position_ids,
causal_mask,
mode,
cache_view,
cache_metadata,
segment_ids,
output_attentions,
fcm_mask,
frequencies,
)
hidden_states = self.resid_attn_dropout(attn_outputs.attention_output) + residual
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
if self.config.use_scan_mlp:
feed_forward_hidden_states = block_wise_ffn(
self.mlp,
hidden_states,
self.config.scan_mlp_chunk_size,
)
else:
feed_forward_hidden_states = self.mlp(hidden_states)
hidden_states = residual + self.resid_mlp_dropout(feed_forward_hidden_states)
hidden_states = apply_logical_sharding(
hidden_states,
dynamic_axes=common_types.HiddenStateSharding,
partition_manager=self.config.partition_manager,
)
return DecoderLayerOutput(
hidden_states=hidden_states,
attention_weight=attn_outputs.attention_weight,
cache_view=attn_outputs.cache_view,
)
[docs]@register_module(
TaskType.BASE_MODULE,
config=Phi3Config,
model_type="phi3",
)
class Phi3Model(EasyDeLBaseModule):
"""The base Phi-3 model transformer.
This class represents the core transformer architecture of the Phi-3 model,
consisting of an embedding layer, multiple Phi3DecoderLayer layers,
and a final RMS normalization layer.
Attributes:
config (Phi3Config): Configuration object for the model.
dtype (jnp.dtype): Data type for computation.
param_dtype (jnp.dtype): Data type for parameters.
precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
rngs (nn.Rngs): Random number generators.
embed_tokens (nn.Embed): Embedding layer for input tokens.
embed_dropout (nn.Dropout): Dropout layer applied after embeddings.
layers (tp.List[Phi3DecoderLayer]): List of decoder layers.
norm (RMSNorm): Final layer normalization.
gradient_checkpointing (EasyDeLGradientCheckPointers): Gradient checkpointing configuration.
"""
def __init__(
self,
config: Phi3Config,
dtype: jnp.dtype = jnp.float32,
param_dtype: jnp.dtype = jnp.float32,
precision: jax.lax.PrecisionLike = None,
*,
rngs: nn.Rngs,
):
"""Initializes the Phi3Model.
Args:
config (Phi3Config): The configuration object for the Phi-3 model.
dtype (jnp.dtype): Data type for computation. Defaults to jnp.float32.
param_dtype (jnp.dtype): Data type for parameters. Defaults to jnp.float32.
precision (jax.lax.PrecisionLike): Precision setting for JAX operations. Defaults to None.
rngs (nn.Rngs): Random number generators.
"""
super().__init__(
config=config,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embed(
config.vocab_size,
config.hidden_size,
dtype=dtype,
param_dtype=param_dtype,
rngs=rngs,
)
self.embed_dropout = nn.Dropout(config.embd_pdrop)
self.layers = [
Phi3DecoderLayer(
config=config,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
for idx in range(self.config.num_hidden_layers)
]
self.norm = RMSNorm(
config.hidden_size,
eps=config.rms_norm_eps,
dtype=dtype,
param_dtype=param_dtype,
)
@functools.cached_property
def frequencies(self):
return self.config.get_basic_frequencies(
head_size=self.config.hidden_size // self.config.num_attention_heads,
base=self.config.rope_theta,
)
def __call__(
self,
input_ids: tp.Optional[chex.Array] = None,
inputs_embeds: tp.Optional[chex.Array] = None,
attention_mask: tp.Optional[chex.Array] = None,
position_ids: tp.Optional[chex.Array] = None,
segment_ids: tp.Optional[chex.Array] = None,
output_attentions: tp.Optional[bool] = None,
output_hidden_states: tp.Optional[bool] = None,
mode: tp.Optional[common_types.RUNTIME_MODE_TYPES] = None, # type:ignore
past_key_values: tp.Optional[TransformerCache | PagedAttentionCache] = None,
cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None,
) -> BaseModelOutput:
"""Forward pass of the Phi3Model.
Args:
input_ids (tp.Optional[chex.Array]): Input token IDs. Shape: (batch_size, sequence_length).
inputs_embeds (tp.Optional[chex.Array]): Input embeddings. Shape: (batch_size, sequence_length, hidden_size).
Either `input_ids` or `inputs_embeds` must be provided.
attention_mask (tp.Optional[chex.Array]): Mask to avoid performing attention on padding token indices.
Shape: (batch_size, sequence_length).
position_ids (tp.Optional[chex.Array]): Position indices for the tokens.
Shape: (batch_size, sequence_length).
segment_ids (tp.Optional[chex.Array]): Segment IDs (unused).
output_attentions (tp.Optional[bool]): Whether to return attention weights. Defaults to `config.output_attentions`.
output_hidden_states (tp.Optional[bool]): Whether to return hidden states for all layers.
Defaults to `config.output_hidden_states`.
past_key_values (tp.Optional[TransformerCache | PagedAttentionCache]): Precomputed key/value states for attention.
cache_metadata (tp.Optional[TransformerMetadata | PagedAttentionMetadata]): Metadata for paged attention.
Returns:
BaseModelOutput: The model's output.
returns a `BaseModelOutput` object containing `last_hidden_state`, `hidden_states` (optional),
and `attentions` (optional).
Raises:
ValueError: If neither `input_ids` nor `inputs_embeds` is provided.
"""
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids.astype("i4"))
batch_size, sequence_length, _ = inputs_embeds.shape
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
assert sequence_length <= self.config.max_position_embeddings, (
f"Maximum Position Embedding Reached ! (Excepted <= {self.config.max_position_embeddings} got {sequence_length})"
)
if attention_mask is None:
attention_mask = jnp.ones((batch_size, sequence_length), "b1")
else:
if attention_mask.dtype != jnp.bool:
attention_mask = jnp.astype(attention_mask == 1, "b1")
if position_ids is None:
position_ids = jnp.broadcast_to(
jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
(batch_size, sequence_length),
).astype(jnp.int32)
if attention_mask.ndim == 2:
attention_mask = jnp.expand_dims(attention_mask, (1, 2))
if mode is None:
mode = (
common_types.MODE_DECODE
if sequence_length == 1 and past_key_values is not None
else common_types.MODE_TRAIN
)
if past_key_values is None:
past_key_values = TransformerCache.init_empty(len(self.layers))
hidden_states = apply_logical_sharding(
inputs_embeds,
dynamic_axes=common_types.HiddenStateSharding,
partition_manager=self.config.partition_manager,
)
for idx, block in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = block(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
mode=mode,
cache_view=past_key_values.views[idx],
cache_metadata=cache_metadata,
causal_mask=self.causal_mask,
output_attentions=output_attentions,
segment_ids=segment_ids,
frequencies=self.frequencies,
)
hidden_states = layer_outputs.hidden_states
if output_attentions:
all_attentions += (layer_outputs.attention_weight,)
past_key_values[idx] = layer_outputs.cache_view
hidden_states = self.norm(hidden_states)
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
past_key_values=past_key_values,
)
[docs]@register_module(
TaskType.CAUSAL_LM,
config=Phi3Config,
model_type="phi3",
)
class Phi3ForCausalLM(EasyDeLBaseModule):
"""Phi-3 model with a Causal Language Modeling head.
This model consists of the base Phi-3 transformer (`Phi3Model`) followed by a
linear layer (`lm_head`) that projects the transformer's output hidden states
to the vocabulary size, producing logits for next token prediction.
Optionally, the input token embeddings can be tied to the output projection layer.
Attributes:
config (Phi3Config): Configuration object for the model.
dtype (jnp.dtype): Data type for computation.
param_dtype (jnp.dtype): Data type for parameters.
precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
rngs (nn.Rngs): Random number generators.
model (Phi3Model): The core Phi-3 transformer model.
lm_head (ParallelLinear): The linear layer for projecting hidden states to vocabulary logits.
"""
def __init__(
self,
config: Phi3Config,
dtype: jnp.dtype = jnp.float32,
param_dtype: jnp.dtype = jnp.float32,
precision: jax.lax.PrecisionLike = None,
*,
rngs: nn.Rngs,
):
"""Initializes the Phi3ForCausalLM model.
Args:
config (Phi3Config): The configuration object for the Phi-3 model.
dtype (jnp.dtype): Data type for computation. Defaults to jnp.float32.
param_dtype (jnp.dtype): Data type for parameters. Defaults to jnp.float32.
precision (jax.lax.PrecisionLike): Precision setting for JAX operations. Defaults to None.
rngs (nn.Rngs): Random number generators.
"""
super().__init__(
config=config,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
self.model = Phi3Model(
config=config,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
self.vocab_size = self.config.vocab_size
self.lm_head = ParallelLinear(
config.hidden_size,
config.vocab_size,
use_bias=False,
kernel_init=jax.nn.initializers.normal(config.initializer_range),
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
def __call__(
self,
input_ids: tp.Optional[chex.Array] = None,
inputs_embeds: tp.Optional[chex.Array] = None,
attention_mask: tp.Optional[chex.Array] = None,
position_ids: tp.Optional[chex.Array] = None,
segment_ids: tp.Optional[chex.Array] = None,
output_attentions: tp.Optional[bool] = None,
output_hidden_states: tp.Optional[bool] = None,
mode: tp.Optional[common_types.RUNTIME_MODE_TYPES] = None, # type:ignore
past_key_values: tp.Optional[TransformerCache | PagedAttentionCache] = None,
cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None,
) -> CausalLMOutput:
"""Forward pass of the Phi3ForCausalLM model.
Args:
input_ids (tp.Optional[chex.Array]): Input token IDs. Shape: (batch_size, sequence_length).
inputs_embeds (tp.Optional[chex.Array]): Input embeddings. Shape: (batch_size, sequence_length, hidden_size).
Either `input_ids` or `inputs_embeds` must be provided.
attention_mask (tp.Optional[chex.Array]): Mask to avoid performing attention on padding token indices.
Shape: (batch_size, sequence_length).
position_ids (tp.Optional[chex.Array]): Position indices for the tokens.
Shape: (batch_size, sequence_length).
segment_ids (tp.Optional[chex.Array]): Segment IDs (unused).
output_attentions (tp.Optional[bool]): Whether to return attention weights. Defaults to `config.output_attentions`.
output_hidden_states (tp.Optional[bool]): Whether to return hidden states for all layers.
Defaults to `config.output_hidden_states`.
past_key_values (tp.Optional[TransformerCache | PagedAttentionCache]): Precomputed key/value states for attention.
cache_metadata (tp.Optional[TransformerMetadata | PagedAttentionMetadata]): Metadata for paged attention.
Returns:
CausalLMOutput: The model's output.
returns a `CausalLMOutput` object containing `logits`, `hidden_states` (optional),
and `attentions` (optional).
"""
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
mode=mode,
past_key_values=past_key_values,
cache_metadata=cache_metadata,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
inputs_embeds=inputs_embeds,
segment_ids=segment_ids,
)
hidden_states = outputs.last_hidden_state
if self.config.tie_word_embeddings:
lm_logits = jax.lax.dot_general(
hidden_states,
self.model.embed_tokens.embedding.value.T,
(((hidden_states.ndim - 1), (0,)), ((), ())),
)
else:
lm_logits = self.lm_head(hidden_states)
return CausalLMOutput(
logits=lm_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
past_key_values=outputs.past_key_values,
)