easydel.modules.phi.modeling_phi_flax#
- class easydel.modules.phi.modeling_phi_flax.PhiAttention(*args: Any, **kwargs: Any)[source]#
Bases: AttentionModule
Phi Attention module.
This module implements the multi-head attention mechanism used in the Phi model. It supports Grouped Query Attention (GQA), partial Rotary Position Embeddings (RoPE), and optional Layer Normalization for query and key projections.
- layer_idx#
Index of the current layer.
- Type
int, optional
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- attention_dropout#
Dropout probability for attention scores.
- Type
float
Dimensionality of the hidden states.
- Type
int
- num_heads#
Number of attention query heads.
- Type
int
- head_dim#
Dimensionality of each attention head.
- Type
int
- num_key_value_heads#
Number of attention key/value heads (for GQA).
- Type
int
- num_key_value_groups#
Number of query heads that share each key/value head (num_heads // num_key_value_heads).
- Type
int
- max_position_embeddings#
Maximum sequence length supported by RoPE.
- Type
int
- rope_theta#
Base value for RoPE frequency calculation.
- Type
float
- partial_rotary_factor#
Factor determining the fraction of head dimension subject to RoPE.
- Type
float
- is_causal#
Whether the attention is causal (always True for this implementation).
- Type
bool
- q_proj#
Linear layer for query projection.
- Type
- k_proj#
Linear layer for key projection.
- Type
- v_proj#
Linear layer for value projection.
- Type
- dense#
Linear layer for the output projection.
- Type
- rotary_emb_dim#
The dimension of the rotary embeddings.
- Type
int
- qk_layernorm#
Whether to apply LayerNorm to query and key projections.
- Type
bool
- q_layernorm#
Layer normalization for query projections.
- Type
nn.LayerNorm, optional
- k_layernorm#
Layer normalization for key projections.
- Type
nn.LayerNorm, optional
- attention_performer#
Module to perform the core attention computation.
- rotary#
Rotary position embedding module.
- Type
RoPE
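The two mechanisms named in the class description, partial RoPE and Grouped Query Attention, can be illustrated with a minimal JAX sketch. This is not the EasyDeL implementation: `apply_partial_rope`, `rotate_half`, and `repeat_kv` are hypothetical helper names, the half-split rotation layout is an assumption, and `rotary_emb_dim` is assumed to equal `int(partial_rotary_factor * head_dim)` and to be even.

```python
import jax.numpy as jnp


def rotate_half(x):
    """Split the last axis into halves (a, b) and return (-b, a)."""
    a, b = jnp.split(x, 2, axis=-1)
    return jnp.concatenate([-b, a], axis=-1)


def apply_partial_rope(x, positions, partial_rotary_factor=0.5, rope_theta=10000.0):
    """Rotate only the first rotary_emb_dim features of every head.

    x: [batch, seq, heads, head_dim]; positions: [seq].
    """
    head_dim = x.shape[-1]
    rotary_emb_dim = int(partial_rotary_factor * head_dim)  # assumed to be even
    x_rot, x_pass = x[..., :rotary_emb_dim], x[..., rotary_emb_dim:]

    # One inverse frequency per rotated feature pair.
    inv_freq = 1.0 / (rope_theta ** (jnp.arange(0, rotary_emb_dim, 2) / rotary_emb_dim))
    angles = positions[:, None] * inv_freq[None, :]      # [seq, rotary_emb_dim // 2]
    angles = jnp.concatenate([angles, angles], axis=-1)  # [seq, rotary_emb_dim]
    cos = jnp.cos(angles)[None, :, None, :]
    sin = jnp.sin(angles)[None, :, None, :]

    x_rot = x_rot * cos + rotate_half(x_rot) * sin
    # The remaining features pass through without positional rotation.
    return jnp.concatenate([x_rot, x_pass], axis=-1)


def repeat_kv(kv, num_key_value_groups):
    """Expand [batch, seq, kv_heads, head_dim] so each KV head serves a group of query heads."""
    return jnp.repeat(kv, num_key_value_groups, axis=2)


# Toy shapes: 8 query heads, 2 key/value heads, head_dim 16.
q = jnp.ones((1, 4, 8, 16))
k = jnp.ones((1, 4, 2, 16))
positions = jnp.arange(4)
q = apply_partial_rope(q, positions)
k = repeat_kv(apply_partial_rope(k, positions), num_key_value_groups=8 // 2)
```

After `repeat_kv`, keys and values have the same head count as the queries, so a standard multi-head attention kernel can consume them unchanged.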
- class easydel.modules.phi.modeling_phi_flax.PhiDecoderLayer(*args: Any, **kwargs: Any)[source]#
Bases: Module
Phi Transformer Decoder Layer.
This module represents a single decoder layer in the Phi model, combining self-attention and MLP sub-layers with residual connections and layer normalization.
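As a rough illustration of how such a layer can combine its sub-layers, the sketch below assumes the parallel layout commonly used for Phi: a single layer normalization feeds both attention and the MLP, and both outputs are added to the residual. `decoder_layer`, `layer_norm`, `attn_fn`, and `mlp_fn` are hypothetical names; dropout and the learnable scale and bias of the normalization are omitted.

```python
import jax.numpy as jnp


def layer_norm(x, eps=1e-5):
    """Normalize the last axis to zero mean and unit variance (no learnable scale/bias)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def decoder_layer(hidden_states, attn_fn, mlp_fn):
    """Parallel residual block: both sub-layers read the same normalized input."""
    residual = hidden_states
    normed = layer_norm(hidden_states)
    attn_out = attn_fn(normed)  # stand-in for the self-attention module
    mlp_out = mlp_fn(normed)    # stand-in for the MLP module
    return residual + attn_out + mlp_out


# Toy usage with trivial stand-ins for the two sub-layers.
x = jnp.arange(24, dtype=jnp.float32).reshape(1, 3, 8)
y = decoder_layer(x, attn_fn=lambda h: 0.5 * h, mlp_fn=lambda h: 0.5 * h)
```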
- layer_idx#
Index of the current layer.
- Type
int, optional
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- input_layernorm#
Layer normalization applied before the attention and MLP blocks.
- Type
nn.LayerNorm
- resid_dropout#
Dropout applied to the residual connection after the MLP block.
- Type
nn.Dropout
- self_attn#
The self-attention module.
- Type
- class easydel.modules.phi.modeling_phi_flax.PhiForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases: EasyDeLBaseModule
Phi model with a Causal Language Modeling head.
This model consists of the base Phi transformer (PhiModel) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction. Optionally, the input token embeddings can be tied to the output projection layer.
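The projection to vocabulary logits, with and without tied embeddings, can be sketched as follows. This is a simplified illustration, not the library's code: `lm_logits` is a hypothetical helper, biases are omitted, and tying is modeled by reusing the transposed embedding matrix as the output kernel.

```python
import jax
import jax.numpy as jnp


def lm_logits(hidden_states, embedding, lm_head_kernel=None):
    """Project hidden states to vocabulary logits.

    hidden_states: [batch, seq, hidden]
    embedding:     [vocab, hidden]
    lm_head_kernel: [hidden, vocab], or None to tie with the embedding matrix.
    """
    kernel = embedding.T if lm_head_kernel is None else lm_head_kernel
    return hidden_states @ kernel


key_e, key_h = jax.random.split(jax.random.PRNGKey(0))
embedding = jax.random.normal(key_e, (32, 8))       # toy vocab of 32, hidden size 8
hidden_states = jax.random.normal(key_h, (2, 5, 8))

logits = lm_logits(hidden_states, embedding)        # tied projection, [2, 5, 32]
next_token = jnp.argmax(logits[:, -1, :], axis=-1)  # greedy next-token choice per sequence
```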
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- lm_head#
The linear layer for projecting hidden states to vocabulary logits.
- Type
- class easydel.modules.phi.modeling_phi_flax.PhiMLP(*args: Any, **kwargs: Any)[source]#
Bases: Module
Phi MLP module.
This module implements the feed-forward network (MLP) used in the Phi model. It consists of two linear projections with a GELU activation in between.
- layer_idx#
Index of the current layer.
- Type
int, optional
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- fc1#
First linear projection layer (up-projection).
- Type
- fc2#
Second linear projection layer (down-projection).
- Type
- act#
Activation function.
- Type
callable
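The feed-forward computation described for this class can be sketched in a few lines of JAX. The function name `phi_mlp` and the explicit weight arguments are hypothetical, and the tanh-approximated GELU variant is an assumption; the source only states that the activation is GELU.

```python
import jax
import jax.numpy as jnp


def phi_mlp(x, w1, b1, w2, b2):
    """Two linear projections with a GELU in between.

    x: [..., hidden]; w1: [hidden, intermediate]; w2: [intermediate, hidden].
    """
    h = x @ w1 + b1                       # fc1: up-projection
    h = jax.nn.gelu(h, approximate=True)  # act: assumed tanh-approximated GELU
    return h @ w2 + b2                    # fc2: down-projection


key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key1, (2, 5, 8))
w1 = jax.random.normal(key2, (8, 32)) * 0.02
w2 = jax.random.normal(key3, (32, 8)) * 0.02
y = phi_mlp(x, w1, jnp.zeros(32), w2, jnp.zeros(8))  # same shape as x
```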
- class easydel.modules.phi.modeling_phi_flax.PhiModel(*args: Any, **kwargs: Any)[source]#
Bases: EasyDeLBaseModule
The base Phi model transformer.
This class represents the core transformer architecture of the Phi model, consisting of an embedding layer, multiple PhiDecoderLayer layers, and a final layer normalization.
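The composition described above (embedding lookup, a stack of decoder layers, then a final normalization) can be sketched as a plain function. `phi_model_forward` is a hypothetical name, the layers are passed as arbitrary callables, and dropout, attention masks, and caching are left out.

```python
import jax.numpy as jnp


def phi_model_forward(input_ids, embedding, layers, final_norm):
    """Embed tokens, apply each decoder layer in order, then normalize.

    input_ids: [batch, seq] integer token ids; embedding: [vocab, hidden].
    """
    hidden_states = embedding[input_ids]  # embedding lookup -> [batch, seq, hidden]
    for layer in layers:
        hidden_states = layer(hidden_states)
    return final_norm(hidden_states)


# Toy usage with stand-in layers that only rescale their input.
embedding = jnp.arange(12, dtype=jnp.float32).reshape(4, 3)
input_ids = jnp.array([[0, 2, 3]])
hidden = phi_model_forward(
    input_ids,
    embedding,
    layers=[lambda h: h * 0.5, lambda h: h + 1.0],
    final_norm=lambda h: h - h.mean(axis=-1, keepdims=True),
)
```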
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- embed_tokens#
Embedding layer for input tokens.
- Type
nn.Embed
- layers#
List of decoder layers.
- Type
tp.List[PhiDecoderLayer]
- final_layernorm#
Final layer normalization.
- Type
nn.LayerNorm
- embed_dropout#
Dropout layer applied after embeddings.
- Type
nn.Dropout
- gradient_checkpointing#
Gradient checkpointing configuration.
- property frequencies#
Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.
Uses self.config.get_basic_frequencies() and caches the result.
- Returns
The frequency components, potentially cached.
- Return type
jnp.ndarray
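The retrieve-or-compute behavior described for this property follows a common lazy-caching pattern, sketched below. `ToyConfig` and `ToyModel` are hypothetical stand-ins: the real `get_basic_frequencies()` signature and the exact contents of its result are not shown in this page, so the inverse-frequency formula here is only an assumed placeholder.

```python
import jax.numpy as jnp


class ToyConfig:
    """Hypothetical stand-in for the model configuration."""

    def __init__(self, head_dim=8, rope_theta=10000.0):
        self.head_dim = head_dim
        self.rope_theta = rope_theta
        self.calls = 0  # counts how often the frequencies are recomputed

    def get_basic_frequencies(self):
        # Placeholder computation: standard RoPE inverse frequencies.
        self.calls += 1
        exponents = jnp.arange(0, self.head_dim, 2) / self.head_dim
        return 1.0 / (self.rope_theta ** exponents)


class ToyModel:
    """Hypothetical stand-in showing the cached property."""

    def __init__(self, config):
        self.config = config
        self._frequencies = None

    @property
    def frequencies(self):
        # Compute on first access, then reuse the stored array.
        if self._frequencies is None:
            self._frequencies = self.config.get_basic_frequencies()
        return self._frequencies


model = ToyModel(ToyConfig())
first = model.frequencies
second = model.frequencies  # served from the cache, no second computation
```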