easydel.modules.phi.modeling_phi

class easydel.modules.phi.modeling_phi.PhiAttention(*args: Any, **kwargs: Any)

Bases: UnifiedAttention

Phi Attention with Q/K normalization.

Follows the Q/K normalization scheme of QKNormAttention, with these Phi-specific features:

- Uses LayerNorm instead of RMSNorm
- Standard LayerNorm on the full hidden_size (not per-head)
- Partial RoPE (controlled by partial_rotary_factor)
- Custom bias configuration
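
Partial RoPE means that only the first `rotary_dim = head_dim * partial_rotary_factor` channels of each query/key head are rotated, while the remaining channels pass through unchanged. The sketch below is a minimal illustration under stated assumptions: it uses the rotate-half (GPT-NeoX style) convention of the Hugging Face Phi port and a base of 10000. The helper names `rotate_half` and `apply_partial_rope` are hypothetical, and EasyDeL's own kernel may lay out the frequencies differently.

```python
import jax.numpy as jnp


def rotate_half(x):
    # Split the last axis into two halves and map (x1, x2) -> (-x2, x1).
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate([-x2, x1], axis=-1)


def apply_partial_rope(x, positions, partial_rotary_factor=0.5, base=10000.0):
    """Rotate only the first `rotary_dim` channels of each head (assumed NeoX-style).

    x: [seq_len, num_heads, head_dim]; positions: [seq_len] integer positions.
    """
    head_dim = x.shape[-1]
    rotary_dim = int(head_dim * partial_rotary_factor)
    # Standard RoPE inverse frequencies, one per rotated channel pair.
    inv_freq = 1.0 / (base ** (jnp.arange(0, rotary_dim, 2, dtype=jnp.float32) / rotary_dim))
    angles = positions[:, None].astype(jnp.float32) * inv_freq[None, :]  # [seq, rotary_dim // 2]
    emb = jnp.concatenate([angles, angles], axis=-1)[:, None, :]  # [seq, 1, rotary_dim]
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x_rot = x_rot * jnp.cos(emb) + rotate_half(x_rot) * jnp.sin(emb)
    # Channels beyond rotary_dim are left untouched.
    return jnp.concatenate([x_rot, x_pass], axis=-1)


q = jnp.ones((4, 2, 8))  # [seq_len=4, num_heads=2, head_dim=8]
out = apply_partial_rope(q, jnp.arange(4), partial_rotary_factor=0.5)
```

With `head_dim=8` and a factor of 0.5, channels 4 to 7 are copied through as-is, and position 0 is unchanged because its rotation angle is zero.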

norms_mapping: ClassVar[dict[str, str]] = {'key_normalization': 'k_layernorm', 'query_normalization': 'q_layernorm'}
projection_mapping: ClassVar[dict[str, str]] = {'key_projection': 'k_proj', 'output_projection': 'dense', 'query_projection': 'q_proj', 'value_projection': 'v_proj'}
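
One plausible reading of these two mappings is that they translate generic role names (such as `output_projection`) into Phi's own attribute names (such as `dense`, the name used in Hugging Face Phi checkpoints). The toy class below is hypothetical and only illustrates that name lookup. It is not how UnifiedAttention is implemented, and the string placeholders stand in for real layers.

```python
from typing import ClassVar


class ToyPhiAttention:
    """Hypothetical stand-in whose mappings mirror the documented ClassVars."""

    projection_mapping: ClassVar[dict[str, str]] = {
        "query_projection": "q_proj",
        "key_projection": "k_proj",
        "value_projection": "v_proj",
        "output_projection": "dense",
    }
    norms_mapping: ClassVar[dict[str, str]] = {
        "query_normalization": "q_layernorm",
        "key_normalization": "k_layernorm",
    }

    def __init__(self):
        # Placeholder strings stand in for the real projection and norm layers.
        for name in [*self.projection_mapping.values(), *self.norms_mapping.values()]:
            setattr(self, name, f"<{name}>")

    def resolve(self, role: str):
        # Map a generic role name to the Phi-specific attribute, then fetch it.
        mapping = {**self.projection_mapping, **self.norms_mapping}
        return getattr(self, mapping[role])


attn = ToyPhiAttention()
print(attn.resolve("output_projection"))  # <dense>
```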

class easydel.modules.phi.modeling_phi.PhiDecoderLayer(*args: Any, **kwargs: Any)

Bases: Module

Phi Transformer Decoder Layer.

This module represents a single decoder layer in the Phi model, combining self-attention and MLP sub-layers with residual connections and layer normalization.

config

Configuration object for the model.

Type

PhiConfig

layer_idx

Index of the current layer.

Type

int, optional

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

input_layernorm

Layer normalization applied before the attention and MLP blocks.

Type

nn.LayerNorm

resid_dropout

Dropout applied to the residual connection after the MLP block.

Type

nn.Dropout

self_attn

The self-attention module.

Type

PhiAttention

mlp

The feed-forward (MLP) module.

Type

PhiMLP
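
A single input_layernorm feeds both the attention and the MLP blocks, so the layer reads as a parallel residual block. This is also the layout used by the Hugging Face Phi port. The sketch below assumes that layout and shows only the inference path, so resid_dropout is omitted. A parameter-free LayerNorm stands in for nn.LayerNorm, and `attn_fn` / `mlp_fn` are hypothetical callables standing in for self_attn and mlp.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-5):
    # Parameter-free LayerNorm over the feature axis (learned scale/bias omitted).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def phi_decoder_layer(x, attn_fn, mlp_fn):
    """Parallel residual block (assumed layout; inference path, dropout omitted)."""
    residual = x
    h = layer_norm(x)      # one input_layernorm shared by both branches
    attn_out = attn_fn(h)  # self-attention branch
    mlp_out = mlp_fn(h)    # MLP branch reads the same normalized input, not attn_out
    return residual + attn_out + mlp_out


def zero_branch(h):
    return jnp.zeros_like(h)


x = jax.random.normal(jax.random.PRNGKey(0), (3, 16))  # [seq_len, hidden_size]
y = phi_decoder_layer(x, attn_fn=zero_branch, mlp_fn=zero_branch)
```

When both branches return zeros, the layer reduces to the identity through the residual path. This is a quick sanity check on the wiring.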

class easydel.modules.phi.modeling_phi.PhiForCausalLM(*args: Any, **kwargs: Any)

Bases: BaseCausalLMModule[PhiModel, PhiConfig]

Phi model with a Causal Language Modeling head.

get_decoder()

Returns the decoder part of the model’s graph definition.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()

Returns the language model head of the module.
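
A causal language modeling head projects each hidden state to vocabulary logits, and the training objective predicts token t+1 from position t. The sketch below makes a few assumptions. The head is a dense projection with bias, as in the Hugging Face Phi port. The names `lm_head_w`, `lm_head_b`, and `causal_lm_loss` are hypothetical. No label masking is applied, whereas EasyDeL's own loss path may mask padding or ignored labels.

```python
import jax
import jax.numpy as jnp


def causal_lm_loss(hidden_states, lm_head_w, lm_head_b, input_ids):
    """hidden_states: [batch, seq, hidden]; lm_head_w: [hidden, vocab]; input_ids: [batch, seq]."""
    logits = hidden_states @ lm_head_w + lm_head_b  # [batch, seq, vocab]
    # Position t predicts token t+1: drop the last logit and the first label.
    shift_logits = logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]
    log_probs = jax.nn.log_softmax(shift_logits, axis=-1)
    nll = -jnp.take_along_axis(log_probs, shift_labels[..., None], axis=-1)[..., 0]
    return logits, nll.mean()


batch, seq, hidden, vocab = 2, 5, 8, 11
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
h = jax.random.normal(k1, (batch, seq, hidden))
w = jax.random.normal(k2, (hidden, vocab)) * 0.02
b = jnp.zeros((vocab,))
ids = jax.random.randint(k3, (batch, seq), 0, vocab)
logits, loss = causal_lm_loss(h, w, b, ids)
```

With an all-zero head, every token is equally likely, and the loss equals log(vocab_size).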

class easydel.modules.phi.modeling_phi.PhiMLP(*args: Any, **kwargs: Any)

Bases: Module

Phi MLP module.

This module implements the feed-forward network (MLP) used in the Phi model. It consists of two linear projections with a GELU activation in between.

config

Configuration object for the model.

Type

PhiConfig

layer_idx

Index of the current layer.

Type

int, optional

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

fc1

First linear projection layer (up-projection).

Type

ParallelLinear

fc2

Second linear projection layer (down-projection).

Type

ParallelLinear

act

Activation function.

Type

callable
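
The MLP data flow is fc1 (up-projection), then the activation, then fc2 (down-projection). The sketch below uses plain matrices in place of ParallelLinear, so no sharding is involved. It assumes the tanh-approximate GELU, which corresponds to the `gelu_new` activation in the Hugging Face Phi configs. In the real module, `act` comes from the config. The function name `phi_mlp` and the sizes are hypothetical.

```python
import jax
import jax.numpy as jnp


def phi_mlp(x, w1, b1, w2, b2):
    """fc1 (up-projection) -> GELU -> fc2 (down-projection)."""
    h = x @ w1 + b1                       # fc1: hidden_size -> intermediate_size
    h = jax.nn.gelu(h, approximate=True)  # tanh-approximate GELU (assumed "gelu_new")
    return h @ w2 + b2                    # fc2: intermediate_size -> hidden_size


hidden, inter = 8, 32
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (4, hidden))
w1 = jax.random.normal(k2, (hidden, inter)) * 0.02
w2 = jax.random.normal(k3, (inter, hidden)) * 0.02
y = phi_mlp(x, w1, jnp.zeros(inter), w2, jnp.zeros(hidden))
```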

class easydel.modules.phi.modeling_phi.PhiModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

The base Phi model transformer.

This class represents the core transformer architecture of the Phi model, consisting of an embedding layer, multiple PhiDecoderLayer layers, and a final layer normalization.

config

Configuration object for the model.

Type

PhiConfig

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

embed_tokens

Embedding layer for input tokens.

Type

nn.Embed

layers

List of decoder layers.

Type

tp.List[PhiDecoderLayer]

final_layernorm

Final layer normalization.

Type

nn.LayerNorm

embed_dropout

Dropout layer applied after embeddings.

Type

nn.Dropout

gradient_checkpointing

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers
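
At a high level, the base model looks up token embeddings, runs them through the stack of decoder layers, and applies final_layernorm. The sketch below is simplified and rests on several assumptions. The decoder layers are arbitrary callables that take only hidden states; the real layers also receive masks and RoPE frequencies. embed_dropout and gradient checkpointing are omitted, and a parameter-free LayerNorm stands in for nn.LayerNorm. The function name `phi_model_forward` is hypothetical.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-5):
    # Parameter-free LayerNorm over the feature axis.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def phi_model_forward(input_ids, embedding, layers):
    """input_ids: [batch, seq]; embedding: [vocab, hidden]; layers: list of callables."""
    h = jnp.take(embedding, input_ids, axis=0)  # embed_tokens lookup -> [batch, seq, hidden]
    # embed_dropout would be applied here during training (omitted: inference path).
    for layer in layers:  # stack of decoder layers, applied in order
        h = layer(h)
    return layer_norm(h)  # final_layernorm


vocab, hidden = 10, 6
emb = jax.random.normal(jax.random.PRNGKey(0), (vocab, hidden))
ids = jnp.array([[1, 2, 3]])
out = phi_model_forward(ids, emb, layers=[lambda h: h + 1.0, lambda h: h * 2.0])
```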

property frequencies

Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_frequencies() and caches the result.

Returns

The frequency components, potentially cached.

Return type

jnp.ndarray
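
The compute-once-then-reuse behaviour of `frequencies` can be pictured with `functools.cached_property`. Everything in the sketch below is a hypothetical stand-in: `ToyRopeConfig`, its field names, and the output layout of its `get_basic_frequencies` (cos and sin tables concatenated over a partial rotary dimension). EasyDeL's real config method and caching mechanism may differ.

```python
from functools import cached_property

import jax.numpy as jnp


class ToyRopeConfig:
    """Hypothetical config; field names and output layout are illustrative only."""

    def __init__(self, head_dim=8, partial_rotary_factor=0.5, rope_theta=10000.0, max_position=16):
        self.head_dim = head_dim
        self.partial_rotary_factor = partial_rotary_factor
        self.rope_theta = rope_theta
        self.max_position = max_position
        self.calls = 0  # counts computations so caching can be checked

    def get_basic_frequencies(self):
        self.calls += 1
        rotary_dim = int(self.head_dim * self.partial_rotary_factor)
        exponents = jnp.arange(0, rotary_dim, 2, dtype=jnp.float32) / rotary_dim
        inv_freq = 1.0 / (self.rope_theta ** exponents)
        angles = jnp.arange(self.max_position, dtype=jnp.float32)[:, None] * inv_freq[None, :]
        # Assumed layout: [max_position, rotary_dim] = cos table followed by sin table.
        return jnp.concatenate([jnp.cos(angles), jnp.sin(angles)], axis=-1)


class ToyModel:
    def __init__(self, config):
        self.config = config

    @cached_property
    def frequencies(self):
        # Computed on first access, then stored on the instance and reused.
        return self.config.get_basic_frequencies()


model = ToyModel(ToyRopeConfig())
f1 = model.frequencies
f2 = model.frequencies
```

Accessing the property twice triggers a single computation. This matters because the same table is shared by every decoder layer.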

get_decoder()

Returns the decoder part of the model’s graph definition.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()

Returns the language model head of the module. Base models don’t have a language model head.