easydel.modules.phi.modeling_phi
- class easydel.modules.phi.modeling_phi.PhiAttention(*args: Any, **kwargs: Any)
Bases: UnifiedAttention
Phi attention with Q/K normalization.
Inherits Q/K normalization from QKNormAttention. Features:
  - Uses LayerNorm instead of RMSNorm
  - Standard LayerNorm on the full hidden_size (not per-head)
  - Partial RoPE (partial_rotary_factor)
  - Custom bias configuration
- norms_mapping: ClassVar[dict[str, str]] = {'key_normalization': 'k_layernorm', 'query_normalization': 'q_layernorm'}
- projection_mapping: ClassVar[dict[str, str]] = {'key_projection': 'k_proj', 'output_projection': 'dense', 'query_projection': 'q_proj', 'value_projection': 'v_proj'}
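The two attention features above that are easiest to get wrong, LayerNorm-based Q/K normalization over the full hidden_size and partial RoPE, can be sketched in NumPy (standing in for jax.numpy) as follows. This is a minimal illustration, not EasyDeL's implementation: `layer_norm`, `rotate_half`, and `apply_partial_rope` are hypothetical helpers, and the half-split ("rotate_half") RoPE convention and `base=10000.0` are assumptions.

```python
import numpy as np


def layer_norm(x, gamma, beta, eps=1e-5):
    """Standard LayerNorm over the last axis: subtract the mean, then scale (unlike RMSNorm)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta


def rotate_half(x):
    """Map the two halves (x1, x2) of the last axis to (-x2, x1)."""
    half = x.shape[-1] // 2
    return np.concatenate([-x[..., half:], x[..., :half]], axis=-1)


def apply_partial_rope(x, positions, partial_rotary_factor=0.5, base=10000.0):
    """Rotate the first rotary_dim channels of every head; pass the rest through.

    x has shape (seq_len, num_heads, head_dim); positions has shape (seq_len,).
    """
    head_dim = x.shape[-1]
    rotary_dim = int(head_dim * partial_rotary_factor)
    inv_freq = 1.0 / (base ** (np.arange(0, rotary_dim, 2) / rotary_dim))
    angles = np.outer(positions, inv_freq)                       # (seq_len, rotary_dim // 2)
    emb = np.concatenate([angles, angles], axis=-1)[:, None, :]  # broadcast over heads
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x_rot = x_rot * np.cos(emb) + rotate_half(x_rot) * np.sin(emb)
    return np.concatenate([x_rot, x_pass], axis=-1)


# Usage: normalize Q over the full hidden_size before splitting into heads,
# then apply partial RoPE per head.
seq_len, num_heads, head_dim = 4, 2, 8
hidden_size = num_heads * head_dim
rng = np.random.default_rng(0)
q = rng.normal(size=(seq_len, hidden_size))
q = layer_norm(q, np.ones(hidden_size), np.zeros(hidden_size))
q = q.reshape(seq_len, num_heads, head_dim)
q_rope = apply_partial_rope(q, np.arange(seq_len), partial_rotary_factor=0.5)
assert np.allclose(q_rope[..., 4:], q[..., 4:])  # unrotated channels are untouched
```

With head_dim = 8 and partial_rotary_factor = 0.5, only channels 0 to 3 of each head are rotated; the remaining channels carry no positional rotation.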
- class easydel.modules.phi.modeling_phi.PhiDecoderLayer(*args: Any, **kwargs: Any)
Bases: Module
Phi Transformer decoder layer.
This module represents a single decoder layer in the Phi model, combining self-attention and MLP sub-layers with residual connections and layer normalization.
- layer_idx (int, optional): Index of the current layer.
- dtype (jnp.dtype): Data type for computations.
- param_dtype (jnp.dtype): Data type for parameters.
- precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
- rngs (nn.Rngs): Random number generators.
- input_layernorm (nn.LayerNorm): Layer normalization applied before the attention and MLP blocks.
- resid_dropout (nn.Dropout): Dropout applied to the residual connection after the MLP block.
- self_attn: The self-attention module.
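A minimal sketch of how these sub-modules could fit together, assuming the parallel arrangement used by Hugging Face's Phi implementation: a single `input_layernorm` output feeds both the attention and MLP branches, `resid_dropout` is applied to each branch output, and both are added to the same residual. The function names are hypothetical, LayerNorm is simplified to its parameter-free form, and EasyDeL's exact ordering should be checked against its source.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """Parameter-free LayerNorm over the last axis (learned scale/bias omitted)."""
    return (x - x.mean(axis=-1, keepdims=True)) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)


def dropout(x, rate, rng=None, deterministic=True):
    """Inverted dropout; a no-op when deterministic or rate == 0."""
    if deterministic or rate == 0.0:
        return x
    keep = rng.random(x.shape) >= rate
    return np.where(keep, x / (1.0 - rate), 0.0)


def phi_decoder_layer(hidden_states, self_attn, mlp, resid_pdrop=0.0, rng=None, deterministic=True):
    """Parallel attention/MLP block: out = x + drop(attn(ln(x))) + drop(mlp(ln(x)))."""
    residual = hidden_states
    normed = layer_norm(hidden_states)  # one shared input_layernorm for both branches
    attn_out = dropout(self_attn(normed), resid_pdrop, rng, deterministic)
    mlp_out = dropout(mlp(normed), resid_pdrop, rng, deterministic)
    return residual + attn_out + mlp_out


# Usage with stand-in branches (real code would pass attention and MLP modules).
x = np.random.default_rng(0).normal(size=(4, 16))
out = phi_decoder_layer(x, self_attn=lambda h: 0.5 * h, mlp=lambda h: np.tanh(h))
print(out.shape)  # (4, 16)
```

Because both branches read the same normalized input, neither depends on the other's output, unlike the sequential attention-then-MLP layout of many other decoders.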
- class easydel.modules.phi.modeling_phi.PhiForCausalLM(*args: Any, **kwargs: Any)
Bases: BaseCausalLMModule[PhiModel, PhiConfig]
Phi model with a causal language modeling head.
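What a causal language-modeling head adds on top of PhiModel can be sketched as a projection from hidden states to vocabulary logits plus a next-token loss in which position t is scored against token t + 1. `lm_head` and `causal_lm_loss` are hypothetical helpers written for illustration (single unbatched sequence, optional bias); they are not the `BaseCausalLMModule` API.

```python
import numpy as np


def lm_head(hidden_states, weight, bias=None):
    """Project (seq_len, hidden_size) hidden states to (seq_len, vocab_size) logits."""
    logits = hidden_states @ weight
    if bias is not None:
        logits = logits + bias
    return logits


def causal_lm_loss(logits, input_ids):
    """Mean next-token cross-entropy where position t is scored against token t + 1."""
    shift_logits = logits[:-1]    # drop the last position: nothing left to predict
    shift_labels = input_ids[1:]  # drop the first token: nothing predicts it
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = shift_logits - shift_logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    nll = -log_probs[np.arange(shift_labels.shape[0]), shift_labels]
    return nll.mean()


# Usage with small random weights: an untrained head scores close to log(vocab_size).
rng = np.random.default_rng(0)
seq_len, hidden_size, vocab_size = 5, 8, 32
hidden = rng.normal(size=(seq_len, hidden_size))
w = rng.normal(scale=0.01, size=(hidden_size, vocab_size))
ids = rng.integers(0, vocab_size, size=seq_len)
loss = causal_lm_loss(lm_head(hidden, w, np.zeros(vocab_size)), ids)
print(loss, np.log(vocab_size))
```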
- class easydel.modules.phi.modeling_phi.PhiMLP(*args: Any, **kwargs: Any)
Bases: Module
Phi MLP module.
This module implements the feed-forward network (MLP) used in the Phi model. It consists of two linear projections with a GELU activation in between.
- layer_idx (int, optional): Index of the current layer.
- dtype (jnp.dtype): Data type for computations.
- param_dtype (jnp.dtype): Data type for parameters.
- precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
- rngs (nn.Rngs): Random number generators.
- fc1: First linear projection layer (up-projection).
- fc2: Second linear projection layer (down-projection).
- act (callable): Activation function.
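The fc1, activation, fc2 data flow described above can be sketched as below. The activation is assumed to be the tanh approximation of GELU, which Hugging Face names "gelu_new" and which released Phi checkpoints typically set as `hidden_act`; in EasyDeL the actual function comes from the config. `phi_mlp` and `gelu_tanh` are hypothetical names.

```python
import numpy as np


def gelu_tanh(x):
    """Tanh approximation of GELU (called "gelu_new" in Hugging Face configs)."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def phi_mlp(x, w1, b1, w2, b2, act=gelu_tanh):
    """fc1 (up-projection), activation, fc2 (down-projection)."""
    h = act(x @ w1 + b1)  # (..., hidden_size) -> (..., intermediate_size)
    return h @ w2 + b2    # (..., intermediate_size) -> (..., hidden_size)


# Usage: hidden_size = 8, intermediate_size = 32.
rng = np.random.default_rng(0)
x = rng.normal(size=(3, 8))
w1, b1 = rng.normal(size=(8, 32)), np.zeros(32)
w2, b2 = rng.normal(size=(32, 8)), np.zeros(8)
print(phi_mlp(x, w1, b1, w2, b2).shape)  # (3, 8)
```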
- class easydel.modules.phi.modeling_phi.PhiModel(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
The base Phi model transformer.
This class represents the core transformer architecture of the Phi model, consisting of an embedding layer, multiple PhiDecoderLayer layers, and a final layer normalization.
- dtype (jnp.dtype): Data type for computation.
- param_dtype (jnp.dtype): Data type for parameters.
- precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
- rngs (nn.Rngs): Random number generators.
- embed_tokens (nn.Embed): Embedding layer for input tokens.
- layers (tp.List[PhiDecoderLayer]): List of decoder layers.
- final_layernorm (nn.LayerNorm): Final layer normalization.
- embed_dropout (nn.Dropout): Dropout layer applied after embeddings.
- gradient_checkpointing: Gradient checkpointing configuration.
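The embedding, decoder-stack, and final-normalization pipeline described for PhiModel can be sketched as follows, with each decoder layer treated as an opaque callable. `phi_model_forward` is a hypothetical function; embedding dropout is omitted (as at inference time), the LayerNorm has no learned scale or bias, and attention masks, caching, and gradient checkpointing are left out.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """Parameter-free LayerNorm over the last axis (learned scale/bias omitted)."""
    return (x - x.mean(axis=-1, keepdims=True)) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)


def phi_model_forward(input_ids, embed_table, layers):
    """embed_tokens, then each decoder layer in order, then final_layernorm."""
    hidden = embed_table[input_ids]  # (seq_len,) -> (seq_len, hidden_size)
    # embed_dropout would be applied here during training; omitted in this sketch.
    for layer in layers:             # each layer maps (seq_len, hidden_size) to the same shape
        hidden = layer(hidden)
    return layer_norm(hidden)        # final_layernorm


# Usage with a toy vocabulary of 10 tokens, hidden_size = 6, and two stand-in layers.
embed_table = np.random.default_rng(0).normal(size=(10, 6))
out = phi_model_forward(np.array([1, 3, 3]), embed_table, layers=[lambda h: 2.0 * h, np.tanh])
print(out.shape)  # (3, 6)
```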
- property frequencies
Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.
Uses self.config.get_basic_frequencies() and caches the result.
- Returns: The frequency components, potentially cached.
- Return type: jnp.ndarray
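The compute-once-then-cache behavior of `frequencies` can be illustrated with `functools.cached_property`. This is a sketch under stated assumptions: `RotaryFrequencies` is a hypothetical class, `functools.cached_property` is only one possible caching mechanism, and the returned table layout (cos and sin of position-by-inverse-frequency angles, covering only the rotary channels) is a guess that need not match what `get_basic_frequencies()` actually returns.

```python
import functools

import numpy as np


class RotaryFrequencies:
    """Hypothetical stand-in for a module exposing a cached `frequencies` property."""

    def __init__(self, head_dim, partial_rotary_factor, max_position_embeddings, base=10000.0):
        self.rotary_dim = int(head_dim * partial_rotary_factor)
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.compute_count = 0  # counts how often the table is actually built

    @functools.cached_property
    def frequencies(self):
        """Build a (max_position_embeddings, rotary_dim) table of [cos | sin] on first access."""
        self.compute_count += 1
        inv_freq = 1.0 / (self.base ** (np.arange(0, self.rotary_dim, 2) / self.rotary_dim))
        angles = np.outer(np.arange(self.max_position_embeddings), inv_freq)
        return np.concatenate([np.cos(angles), np.sin(angles)], axis=-1)


# Usage: the second access returns the cached array without recomputing it.
freqs = RotaryFrequencies(head_dim=64, partial_rotary_factor=0.5, max_position_embeddings=16)
table = freqs.frequencies
assert freqs.frequencies is table
print(table.shape, freqs.compute_count)  # (16, 32) 1
```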