easydel.modules.phi3.modeling_phi3_flax#
- class easydel.modules.phi3.modeling_phi3_flax.Phi3Attention(*args: Any, **kwargs: Any)[source]#
Bases: AttentionModule
Phi3 Attention module.
This module implements the multi-head attention mechanism used in the Phi-3 model. It supports Grouped Query Attention (GQA) and Rotary Position Embeddings (RoPE). The query, key, and value projections are combined into a single linear layer.
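The fused projection and the grouping of heads described above can be sketched in plain JAX. This is a minimal illustration, not EasyDeL's implementation: the sizes are made up, the `[Q | K | V]` layout of the fused output is an assumption, and RoPE and dropout are left out.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes, chosen only to keep the example small.
batch, seq_len = 2, 5
num_heads, num_key_value_heads, head_dim = 8, 2, 8
hidden_size = num_heads * head_dim
num_key_value_groups = num_heads // num_key_value_heads  # 4 query heads per KV head

q_size = num_heads * head_dim
kv_size = num_key_value_heads * head_dim

x_key, w_key = jax.random.split(jax.random.PRNGKey(0))
hidden_states = jax.random.normal(x_key, (batch, seq_len, hidden_size))
# Stand-in for the fused qkv_proj kernel: one matrix yields Q, K and V together.
w_qkv = 0.02 * jax.random.normal(w_key, (hidden_size, q_size + 2 * kv_size))

qkv = hidden_states @ w_qkv
# Assumed layout: [Q | K | V] along the last axis.
q = qkv[..., :q_size].reshape(batch, seq_len, num_heads, head_dim)
k = qkv[..., q_size:q_size + kv_size].reshape(batch, seq_len, num_key_value_heads, head_dim)
v = qkv[..., q_size + kv_size:].reshape(batch, seq_len, num_key_value_heads, head_dim)

# GQA: repeat each key/value head so it lines up with its group of query heads.
k_rep = jnp.repeat(k, num_key_value_groups, axis=2)
v_rep = jnp.repeat(v, num_key_value_groups, axis=2)

# Causal scaled dot-product attention (RoPE and dropout omitted for brevity).
scores = jnp.einsum("bqhd,bkhd->bhqk", q, k_rep) / jnp.sqrt(head_dim)
causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
scores = jnp.where(causal_mask, scores, -jnp.inf)
weights = jax.nn.softmax(scores, axis=-1)
attn_output = jnp.einsum("bhqk,bkhd->bqhd", weights, v_rep).reshape(batch, seq_len, hidden_size)
```

With 8 query heads and 2 key/value heads, the key and value tensors are a quarter the size of the query tensor, which is where GQA saves memory in the KV cache.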
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- attention_dropout#
Dropout probability for attention scores.
- Type
float
Dimensionality of the hidden states.
- Type
int
- num_heads#
Number of attention query heads.
- Type
int
- head_dim#
Dimensionality of each attention head.
- Type
int
- num_key_value_heads#
Number of attention key/value heads (for GQA).
- Type
int
- num_key_value_groups#
Number of query head groups for each key/value head.
- Type
int
- max_position_embeddings#
Maximum sequence length supported by RoPE.
- Type
int
- original_max_position_embeddings#
Original max sequence length for RoPE scaling.
- Type
int
- is_causal#
Whether the attention is causal (always True for this implementation).
- Type
bool
- o_proj#
Linear layer for the output projection.
- Type
- qkv_proj#
Combined linear layer for query, key, and value projections.
- Type
- attention_performer#
Module to perform the core attention computation.
- rotary#
Rotary position embedding module.
- Type
RoPE
- class easydel.modules.phi3.modeling_phi3_flax.Phi3DecoderLayer(*args: Any, **kwargs: Any)[source]#
Bases: Module
Phi3 Transformer Decoder Layer.
This module represents a single decoder layer in the Phi-3 model, combining self-attention and MLP sub-layers with residual connections and RMS normalization.
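The residual wiring can be illustrated with a small sketch. It assumes the usual pre-normalization layout (normalize, apply the sub-layer, add the result back to the input) and an input normalization before attention; the sub-layers are passed in as plain callables, dropout is omitted, and all names here are hypothetical.

```python
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-5):
    """Scale x by the reciprocal root-mean-square of its last axis."""
    mean_square = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x / jnp.sqrt(mean_square + eps) * weight


def decoder_layer(hidden_states, attn_fn, mlp_fn, input_norm_weight, post_attn_norm_weight):
    """Pre-norm decoder block: two sub-layers, each wrapped in a residual connection."""
    # Attention sub-layer with residual connection.
    attn_out = attn_fn(rms_norm(hidden_states, input_norm_weight))
    hidden_states = hidden_states + attn_out
    # MLP sub-layer with residual connection; the norm sits between attention and MLP.
    mlp_out = mlp_fn(rms_norm(hidden_states, post_attn_norm_weight))
    return hidden_states + mlp_out


hidden = jnp.arange(24, dtype=jnp.float32).reshape(1, 3, 8) + 1.0
ones = jnp.ones((8,))
# With sub-layers that output zeros, the residual path returns the input unchanged.
identity_out = decoder_layer(hidden, jnp.zeros_like, jnp.zeros_like, ones, ones)
normed = rms_norm(hidden, ones)
```

Because each sub-layer only adds to the running hidden state, the input always has a direct path to the output of the layer.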
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- self_attn#
The self-attention module.
- Type
- post_attention_layernorm#
RMS normalization applied after the attention layer and before the MLP layer.
- Type
- dropout#
Dropout layer applied to the residual connections.
- Type
nn.Dropout
- class easydel.modules.phi3.modeling_phi3_flax.Phi3ForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases: EasyDeLBaseModule
Phi-3 model with a Causal Language Modeling head.
This model consists of the base Phi-3 transformer (Phi3Model) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction. Optionally, the input token embeddings can be tied to the output projection layer.
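The difference between a separate output projection and tied embeddings can be shown with plain arrays. This is a sketch under assumed shapes (kernel of shape `[hidden, vocab]`, embedding table of shape `[vocab, hidden]`); the variable names are illustrative and do not come from EasyDeL.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes for illustration.
vocab_size, hidden_size = 32, 16
batch, seq_len = 2, 4

emb_key, head_key, h_key = jax.random.split(jax.random.PRNGKey(0), 3)
embedding_table = jax.random.normal(emb_key, (vocab_size, hidden_size))
lm_head_kernel = jax.random.normal(head_key, (hidden_size, vocab_size))
# Stand-in for the transformer's final hidden states.
hidden_states = jax.random.normal(h_key, (batch, seq_len, hidden_size))

# Untied: the head has its own kernel.
logits_untied = hidden_states @ lm_head_kernel
# Tied: the transposed embedding table is reused as the output projection.
logits_tied = hidden_states @ embedding_table.T

# Greedy next-token choice from the last position of each sequence.
next_token = jnp.argmax(logits_tied[:, -1, :], axis=-1)
```

Tying removes one `hidden_size x vocab_size` matrix from the parameter count, at the cost of forcing input and output token representations to share weights.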
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- lm_head#
The linear layer for projecting hidden states to vocabulary logits.
- Type
- class easydel.modules.phi3.modeling_phi3_flax.Phi3MLP(*args: Any, **kwargs: Any)[source]#
Bases: Module
Phi3 MLP module.
This module implements the feed-forward network (MLP) used in the Phi-3 model. It consists of a combined gate and up projection, SiLU activation, and a down projection.
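A minimal sketch of this gated feed-forward computation follows. It assumes the fused projection output is laid out as `[gate | up]` along the last axis and that biases are absent; the function and weight names are hypothetical.

```python
import jax
import jax.numpy as jnp


def gated_mlp(x, w_gate_up, w_down):
    """Fused gate/up projection, SiLU gating, then down projection."""
    gate_up = x @ w_gate_up
    # Assumed layout: first half is the gate, second half is the up projection.
    gate, up = jnp.split(gate_up, 2, axis=-1)
    return (up * jax.nn.silu(gate)) @ w_down


# Hypothetical sizes for illustration.
hidden_size, intermediate_size = 16, 48
x_key, gu_key, d_key = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(x_key, (2, 4, hidden_size))
w_gate_up = 0.1 * jax.random.normal(gu_key, (hidden_size, 2 * intermediate_size))
w_down = 0.1 * jax.random.normal(d_key, (intermediate_size, hidden_size))

mlp_out = gated_mlp(x, w_gate_up, w_down)

# The fused form matches two separate projections that use the same weights.
w_gate, w_up = w_gate_up[:, :intermediate_size], w_gate_up[:, intermediate_size:]
reference = (jax.nn.silu(x @ w_gate) * (x @ w_up)) @ w_down
```

Fusing the two projections replaces two matrix multiplications over the same input with one larger one, which is typically friendlier to accelerators.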
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- gate_up_proj#
Combined linear layer for gate and up projections.
- Type
- down_proj#
Linear layer for the down projection.
- Type
- activation_fn#
Activation function (SiLU).
- Type
callable
- class easydel.modules.phi3.modeling_phi3_flax.Phi3Model(*args: Any, **kwargs: Any)[source]#
Bases: EasyDeLBaseModule
The base Phi-3 model transformer.
This class represents the core transformer architecture of the Phi-3 model, consisting of an embedding layer, multiple Phi3DecoderLayer layers, and a final RMS normalization layer.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- embed_tokens#
Embedding layer for input tokens.
- Type
nn.Embed
- embed_dropout#
Dropout layer applied after embeddings.
- Type
nn.Dropout
- layers#
List of decoder layers.
- Type
tp.List[Phi3DecoderLayer]
- gradient_checkpointing#
Gradient checkpointing configuration.
- property frequencies#
Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.
Uses self.config.get_basic_frequencies() and caches the result.
- Returns
The frequency components, potentially cached.
- Return type
jnp.ndarray
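The compute-once-then-cache pattern behind a property like this can be sketched with `functools.cached_property`. The class below is a hypothetical stand-in, and the table it builds is the standard RoPE angle table (position times inverse frequency); what `get_basic_frequencies()` actually returns in EasyDeL may differ.

```python
from functools import cached_property

import jax.numpy as jnp


class RotaryHolder:
    """Hypothetical stand-in for a model that exposes cached RoPE frequencies."""

    def __init__(self, head_dim, max_position_embeddings, rope_theta=10000.0):
        self.head_dim = head_dim
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta

    @cached_property
    def frequencies(self):
        # One inverse frequency per pair of channels: theta^(-2i / head_dim).
        exponents = jnp.arange(0, self.head_dim, 2) / self.head_dim
        inv_freq = 1.0 / (self.rope_theta ** exponents)
        positions = jnp.arange(self.max_position_embeddings)
        # Rotation angle for every (position, channel pair).
        return jnp.outer(positions, inv_freq)


holder = RotaryHolder(head_dim=8, max_position_embeddings=16)
freqs = holder.frequencies        # computed on first access
freqs_again = holder.frequencies  # served from the cache
```

Caching matters here because the table depends only on the configuration, so recomputing it on every forward pass would be wasted work.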