easydel.modules.phi.modeling_phi_flax

class easydel.modules.phi.modeling_phi_flax.PhiAttention(*args: Any, **kwargs: Any)

Bases: AttentionModule

Phi Attention module.

This module implements the multi-head attention mechanism used in the Phi model. It supports Grouped Query Attention (GQA), partial Rotary Position Embeddings (RoPE), and optional Layer Normalization for query and key projections.

config

Configuration object for the model.

Type

PhiConfig

layer_idx

Index of the current layer.

Type

int, optional

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

attention_dropout

Dropout probability for attention scores.

Type

float

hidden_size

Dimensionality of the hidden states.

Type

int

num_heads

Number of attention query heads.

Type

int

head_dim

Dimensionality of each attention head.

Type

int

num_key_value_heads

Number of attention key/value heads (for GQA).

Type

int

num_key_value_groups

Number of query heads that share each key/value head (num_heads // num_key_value_heads).

Type

int
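With GQA, each key/value head is shared by num_key_value_groups query heads. The minimal JAX sketch below shows one way key/value heads can be expanded to line up with the query heads. It assumes a (batch, seq_len, heads, head_dim) layout, and `repeat_kv` is a hypothetical helper; EasyDeL's attention backends may broadcast instead of materializing the copies.

```python
import jax.numpy as jnp


def repeat_kv(x: jnp.ndarray, num_key_value_groups: int) -> jnp.ndarray:
    """Expand (batch, seq_len, num_key_value_heads, head_dim) so that every
    query head has a matching key/value head."""
    if num_key_value_groups == 1:
        return x  # standard multi-head attention: nothing to expand
    # Output head h reuses key/value head h // num_key_value_groups.
    return jnp.repeat(x, num_key_value_groups, axis=2)


num_heads, num_key_value_heads, head_dim = 8, 2, 4
num_key_value_groups = num_heads // num_key_value_heads  # 4 query heads per K/V head

# Fill each K/V head with its own index so the expansion is easy to inspect.
k = jnp.broadcast_to(
    jnp.arange(num_key_value_heads, dtype=jnp.float32)[None, None, :, None],
    (1, 3, num_key_value_heads, head_dim),
)
k_expanded = repeat_kv(k, num_key_value_groups)
print(k_expanded.shape)  # (1, 3, 8, 4)
```

Query heads 0 to 3 then attend with key/value head 0, and query heads 4 to 7 with key/value head 1.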

max_position_embeddings

Maximum sequence length supported by RoPE.

Type

int

rope_theta

Base value for RoPE frequency calculation.

Type

float

partial_rotary_factor

Fraction of each attention head's dimension to which RoPE is applied; the remaining features are not rotated.

Type

float
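Partial RoPE rotates only the leading part of each head. The minimal JAX sketch below makes two assumptions: that the rotated width is int(head_dim * partial_rotary_factor), and that features are paired with the "rotate_half" convention of the Hugging Face Phi implementation. `apply_partial_rope` is a hypothetical function, and EasyDeL's RoPE module may differ in layout, caching, and scaling.

```python
import jax.numpy as jnp


def apply_partial_rope(x, positions, partial_rotary_factor=0.5, rope_theta=10000.0):
    """Apply RoPE to the first rotary_dim features of each head only.

    x: (batch, seq_len, num_heads, head_dim); positions: (batch, seq_len) integers.
    """
    head_dim = x.shape[-1]
    rotary_dim = int(head_dim * partial_rotary_factor)  # assumed even
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]

    # Standard RoPE inverse frequencies: rope_theta ** (-2i / rotary_dim).
    inv_freq = 1.0 / (rope_theta ** (jnp.arange(0, rotary_dim, 2, dtype=jnp.float32) / rotary_dim))
    angles = positions[..., None].astype(jnp.float32) * inv_freq  # (batch, seq_len, rotary_dim // 2)
    emb = jnp.concatenate([angles, angles], axis=-1)[:, :, None, :]  # broadcast over heads
    cos, sin = jnp.cos(emb), jnp.sin(emb)

    # "rotate_half" pairing: feature j rotates together with feature j + rotary_dim // 2.
    half = rotary_dim // 2
    rotated_half = jnp.concatenate([-x_rot[..., half:], x_rot[..., :half]], axis=-1)
    x_rot = x_rot * cos + rotated_half * sin

    # Features beyond rotary_dim carry no positional rotation.
    return jnp.concatenate([x_rot, x_pass], axis=-1)


x = jnp.arange(1 * 4 * 2 * 8, dtype=jnp.float32).reshape(1, 4, 2, 8)
positions = jnp.arange(4)[None, :]
out = apply_partial_rope(x, positions, partial_rotary_factor=0.5)  # rotates 4 of 8 features
```

Because each pair is rotated rather than rescaled, the rotated slice keeps its norm, and the unrotated slice passes through unchanged.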

is_causal

Whether the attention is causal (always True for this implementation).

Type

bool

q_proj

Linear layer for query projection.

Type

ParallelLinear

k_proj

Linear layer for key projection.

Type

ParallelLinear

v_proj

Linear layer for value projection.

Type

ParallelLinear

dense

Linear layer for the output projection.

Type

ParallelLinear

rotary_emb_dim

The dimension of the rotary embeddings.

Type

int

qk_layernorm

Whether to apply LayerNorm to query and key projections.

Type

bool

q_layernorm

Layer normalization for query projections.

Type

nn.LayerNorm, optional

k_layernorm

Layer normalization for key projections.

Type

nn.LayerNorm, optional
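When qk_layernorm is enabled, queries and keys are normalized before the attention scores are computed. The minimal JAX sketch below assumes that normalization runs over each head's head_dim features. For brevity, it also gives keys the same number of heads as queries (no GQA). `layer_norm` and `project_qk` are hypothetical helpers, not EasyDeL's nn.LayerNorm or ParallelLinear.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, scale, bias, eps=1e-5):
    """LayerNorm over the last axis."""
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mean) * jax.lax.rsqrt(var + eps) * scale + bias


def project_qk(hidden, w_q, w_k, num_heads, qk_layernorm, ln_params=None):
    """hidden: (batch, seq_len, hidden_size); w_q, w_k: (hidden_size, num_heads * head_dim)."""
    batch, seq_len, _ = hidden.shape
    q = (hidden @ w_q).reshape(batch, seq_len, num_heads, -1)
    k = (hidden @ w_k).reshape(batch, seq_len, num_heads, -1)
    if qk_layernorm:
        (q_scale, q_bias), (k_scale, k_bias) = ln_params
        q = layer_norm(q, q_scale, q_bias)  # plays the role of q_layernorm
        k = layer_norm(k, k_scale, k_bias)  # plays the role of k_layernorm
    return q, k


rng_h, rng_q, rng_k = jax.random.split(jax.random.PRNGKey(0), 3)
hidden = jax.random.normal(rng_h, (2, 5, 16))
w_q = jax.random.normal(rng_q, (16, 16))
w_k = jax.random.normal(rng_k, (16, 16))
identity_affine = (jnp.ones(4), jnp.zeros(4))  # scale=1, bias=0 for head_dim=4
q, k = project_qk(hidden, w_q, w_k, num_heads=4, qk_layernorm=True,
                  ln_params=(identity_affine, identity_affine))
```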

attention_performer

Module to perform the core attention computation.

Type

FlexibleAttentionModule
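attention_performer delegates the core computation to one of EasyDeL's attention backends. For causal attention, that computation can be written in plain JAX as below. This is a reference sketch only: it omits dropout, padding masks, and the KV cache. `causal_attention` is a hypothetical name, not the FlexibleAttentionModule API.

```python
import jax
import jax.numpy as jnp


def causal_attention(q, k, v):
    """q, k, v: (batch, seq_len, num_heads, head_dim); returns q's shape."""
    head_dim = q.shape[-1]
    seq_len = q.shape[1]
    # Scaled dot-product scores per head: (batch, num_heads, q_len, k_len).
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) * (head_dim ** -0.5)
    # Causal mask: query position i may attend only to key positions j <= i.
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)


rng_q, rng_k, rng_v = jax.random.split(jax.random.PRNGKey(0), 3)
shape = (1, 4, 2, 8)
q = jax.random.normal(rng_q, shape)
k = jax.random.normal(rng_k, shape)
v = jax.random.normal(rng_v, shape)
out = causal_attention(q, k, v)
```

The first position can only attend to itself, so its output equals its own value vector, and changing a later value never affects earlier outputs.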

rotary

Rotary position embedding module.

Type

RoPE

class easydel.modules.phi.modeling_phi_flax.PhiDecoderLayer(*args: Any, **kwargs: Any)

Bases: Module

Phi Transformer Decoder Layer.

This module represents a single decoder layer in the Phi model, combining self-attention and MLP sub-layers with residual connections and layer normalization.

config

Configuration object for the model.

Type

PhiConfig

layer_idx

Index of the current layer.

Type

int, optional

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

input_layernorm

Layer normalization applied before the attention and MLP blocks.

Type

nn.LayerNorm

resid_dropout

Dropout applied to the residual connection after the MLP block.

Type

nn.Dropout

self_attn

The self-attention module.

Type

PhiAttention

mlp

The feed-forward (MLP) module.

Type

PhiMLP
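The attributes above list a single input_layernorm feeding both sub-layers and no post-attention norm. That matches the parallel attention/MLP block of the Hugging Face Phi reference implementation. Assuming EasyDeL follows the same layout, one layer reduces to the sketch below. Dropout, masks, and caching are omitted. The callables are hypothetical stand-ins for the real submodules.

```python
import jax.numpy as jnp
from typing import Callable

Array = jnp.ndarray


def phi_decoder_layer(
    hidden: Array,
    input_layernorm: Callable[[Array], Array],
    self_attn: Callable[[Array], Array],
    mlp: Callable[[Array], Array],
) -> Array:
    """One Phi-style block with parallel attention and MLP branches."""
    residual = hidden
    normed = input_layernorm(hidden)  # single LayerNorm shared by both branches
    attn_out = self_attn(normed)      # attention branch
    mlp_out = mlp(normed)             # MLP branch reads the same normalized input
    return attn_out + mlp_out + residual


hidden = jnp.ones((1, 2, 4))
out = phi_decoder_layer(
    hidden,
    input_layernorm=lambda x: 2.0 * x,
    self_attn=lambda x: x + 1.0,
    mlp=lambda x: 3.0 * x,
)
```

With these toy callables every element becomes (2 + 1) + 6 + 1 = 10. In a sequential block, the MLP would instead see the attention output.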

class easydel.modules.phi.modeling_phi_flax.PhiForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Phi model with a Causal Language Modeling head.

This model consists of the base Phi transformer (PhiModel) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to vocabulary-sized logits for next-token prediction. Optionally, the lm_head weights can be tied to (shared with) the input token embeddings.
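The projection performed by lm_head, and what tying changes, can be sketched in JAX. The sketch assumes an embedding matrix of shape (vocab_size, hidden_size), an untied kernel of shape (hidden_size, vocab_size), and an optional bias. `lm_logits` is a hypothetical helper, not EasyDeL's ParallelLinear.

```python
import jax
import jax.numpy as jnp


def lm_logits(hidden, embedding, lm_head_kernel=None, lm_head_bias=None):
    """hidden: (batch, seq_len, hidden_size); embedding: (vocab_size, hidden_size)."""
    if lm_head_kernel is None:
        # Tied weights: reuse the input embedding matrix, transposed.
        logits = hidden @ embedding.T
    else:
        # Untied weights: separate (hidden_size, vocab_size) kernel.
        logits = hidden @ lm_head_kernel
    if lm_head_bias is not None:
        logits = logits + lm_head_bias
    return logits  # (batch, seq_len, vocab_size)


rng_h, rng_e = jax.random.split(jax.random.PRNGKey(0))
hidden = jax.random.normal(rng_h, (2, 3, 8))
embedding = jax.random.normal(rng_e, (50, 8))
logits = lm_logits(hidden, embedding)
next_token = jnp.argmax(logits[:, -1, :], axis=-1)  # greedy choice per sequence
```

Tying saves a vocab_size × hidden_size parameter matrix. An untied kernel equal to the transposed embedding gives the same logits.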

config

Configuration object for the model.

Type

PhiConfig

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

transformer

The core Phi transformer model.

Type

PhiModel

lm_head

The linear layer for projecting hidden states to vocabulary logits.

Type

ParallelLinear

class easydel.modules.phi.modeling_phi_flax.PhiMLP(*args: Any, **kwargs: Any)

Bases: Module

Phi MLP module.

This module implements the feed-forward network (MLP) used in the Phi model. It consists of two linear projections with a GELU activation in between.
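The fc1 → activation → fc2 pipeline can be sketched in JAX as below. The sketch assumes plain dense weights with biases and uses the tanh approximation of GELU, which is the `gelu_new` activation named in Phi's Hugging Face configs. The activation EasyDeL actually applies is whatever act resolves to from the config. `phi_mlp` is a hypothetical function.

```python
import jax
import jax.numpy as jnp


def phi_mlp(x, fc1_w, fc1_b, fc2_w, fc2_b):
    """x: (..., hidden_size) -> (..., hidden_size)."""
    h = x @ fc1_w + fc1_b                 # fc1: hidden_size -> intermediate_size
    h = jax.nn.gelu(h, approximate=True)  # tanh-approximated GELU
    return h @ fc2_w + fc2_b              # fc2: intermediate_size -> hidden_size


hidden_size, intermediate_size = 8, 32
keys = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(keys[0], (2, 5, hidden_size))
fc1_w = jax.random.normal(keys[1], (hidden_size, intermediate_size)) * 0.1
fc2_w = jax.random.normal(keys[2], (intermediate_size, hidden_size)) * 0.1
y = phi_mlp(x, fc1_w, jnp.zeros(intermediate_size), fc2_w, jnp.zeros(hidden_size))
```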

config

Configuration object for the model.

Type

PhiConfig

layer_idx

Index of the current layer.

Type

int, optional

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

fc1

First linear projection layer (up-projection).

Type

ParallelLinear

fc2

Second linear projection layer (down-projection).

Type

ParallelLinear

act

Activation function.

Type

callable

class easydel.modules.phi.modeling_phi_flax.PhiModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

The base Phi model transformer.

This class represents the core transformer architecture of the Phi model: a token embedding layer, a stack of PhiDecoderLayer modules, and a final layer normalization.
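Ignoring masks, caching, and dropout, the data flow through PhiModel can be sketched as follows. The callables are hypothetical stand-ins for the real submodules, and `phi_model_forward` is not an EasyDeL function.

```python
import jax.numpy as jnp
from typing import Callable, Sequence

Array = jnp.ndarray


def phi_model_forward(
    input_ids: Array,
    embedding: Array,
    layers: Sequence[Callable[[Array], Array]],
    final_layernorm: Callable[[Array], Array],
) -> Array:
    """input_ids: (batch, seq_len) ints; embedding: (vocab_size, hidden_size)."""
    hidden = jnp.take(embedding, input_ids, axis=0)  # embed_tokens lookup
    # embed_dropout would be applied here during training (omitted).
    for layer in layers:  # the PhiDecoderLayer stack, applied in order
        hidden = layer(hidden)
    return final_layernorm(hidden)


embedding = jnp.arange(6.0).reshape(3, 2)  # vocab_size=3, hidden_size=2
input_ids = jnp.array([[0, 2]])
out = phi_model_forward(
    input_ids,
    embedding,
    layers=[lambda h: h + 1.0] * 3,
    final_layernorm=lambda h: 2.0 * h,
)
```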

config

Configuration object for the model.

Type

PhiConfig

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

embed_tokens

Embedding layer for input tokens.

Type

nn.Embed

layers

List of decoder layers.

Type

tp.List[PhiDecoderLayer]

final_layernorm

Final layer normalization.

Type

nn.LayerNorm

embed_dropout

Dropout layer applied after embeddings.

Type

nn.Dropout

gradient_checkpointing

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers
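gradient_checkpointing controls which activations are stored and which are recomputed during the backward pass. How each EasyDeLGradientCheckPointers value maps onto JAX is not stated here. The sketch below only illustrates the underlying JAX mechanism, jax.checkpoint with a policy from jax.checkpoint_policies, on a hypothetical toy `block`.

```python
import jax
import jax.numpy as jnp


def block(x, w):
    return jnp.tanh(x @ w)


# Save none of block's intermediates; recompute them during backprop instead.
block_remat = jax.checkpoint(block, policy=jax.checkpoint_policies.nothing_saveable)


def make_loss(fn):
    def loss(w, x):
        return jnp.sum(fn(x, w) ** 2)
    return loss


w = 0.5 * jnp.eye(3)
x = jnp.ones((2, 3))
grad_plain = jax.grad(make_loss(block))(w, x)
grad_remat = jax.grad(make_loss(block_remat))(w, x)
```

The gradients are identical. Only the memory/compute trade-off changes.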

property frequencies

Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_frequencies() and caches the result.

Returns

The frequency components, computed on first access and cached thereafter.

Return type

jnp.ndarray
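The exact format returned by get_basic_frequencies() is not documented here. RoPE frequency components are typically built as cos/sin tables for every position up to max_position_embeddings, which is why computing them once and caching them pays off. The function name, arguments, and return layout below are assumptions for illustration.

```python
import jax.numpy as jnp


def rope_frequencies(max_position_embeddings: int, rotary_dim: int, rope_theta: float = 10000.0):
    """Return (cos, sin) tables of shape (max_position_embeddings, rotary_dim // 2)."""
    # Inverse frequency for each rotated pair: rope_theta ** (-2i / rotary_dim).
    inv_freq = 1.0 / (rope_theta ** (jnp.arange(0, rotary_dim, 2, dtype=jnp.float32) / rotary_dim))
    positions = jnp.arange(max_position_embeddings, dtype=jnp.float32)
    angles = jnp.outer(positions, inv_freq)  # angle = position * inv_freq
    return jnp.cos(angles), jnp.sin(angles)


cos, sin = rope_frequencies(max_position_embeddings=2048, rotary_dim=32)
```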