easydel.modules.phi3.modeling_phi3
- class easydel.modules.phi3.modeling_phi3.Phi3Attention(*args: Any, **kwargs: Any)
Bases: UnifiedAttention
Phi-3 attention module with a fused QKV projection.
This module implements the multi-head attention mechanism used in the Phi-3 model. It supports Grouped Query Attention (GQA), Rotary Position Embeddings (RoPE), and sliding window attention. The query, key, and value projections are combined into a single fused linear layer for efficiency.
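The following is a minimal plain-JAX sketch of the fused projection and GQA described above. It is not EasyDeL's implementation. It assumes the fused output is laid out as [Q | K | V] along the last axis (the layout used by the Hugging Face Phi-3 reference), uses illustrative sizes, and applies a simple causal softmax in place of the attention_performer. The helper names split_fused_qkv and gqa_causal_attention are hypothetical.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes (assumed, not Phi-3 defaults).
hidden_size, num_heads, num_kv_heads, head_dim = 32, 8, 2, 4


def split_fused_qkv(qkv, num_heads, num_kv_heads, head_dim):
    """Split a fused projection laid out as [Q | K | V] on the last axis into per-head Q, K, V."""
    q_size = num_heads * head_dim
    kv_size = num_kv_heads * head_dim
    q, k, v = jnp.split(qkv, [q_size, q_size + kv_size], axis=-1)
    batch, seq = qkv.shape[:2]
    q = q.reshape(batch, seq, num_heads, head_dim)
    k = k.reshape(batch, seq, num_kv_heads, head_dim)
    v = v.reshape(batch, seq, num_kv_heads, head_dim)
    return q, k, v


def gqa_causal_attention(q, k, v):
    """Causal GQA: each K/V head is shared by num_heads // num_kv_heads query heads."""
    groups = q.shape[2] // k.shape[2]
    k = jnp.repeat(k, groups, axis=2)  # (batch, seq, num_heads, head_dim)
    v = jnp.repeat(v, groups, axis=2)
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    seq = q.shape[1]
    causal = jnp.tril(jnp.ones((seq, seq), dtype=bool))  # query i sees keys 0..i
    scores = jnp.where(causal, scores, jnp.finfo(scores.dtype).min)
    probs = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", probs, v)


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (1, 5, hidden_size))
w_qkv = jax.random.normal(key_w, (hidden_size, (num_heads + 2 * num_kv_heads) * head_dim))
q, k, v = split_fused_qkv(x @ w_qkv, num_heads, num_kv_heads, head_dim)
out = gqa_causal_attention(q, k, v)
print(q.shape, k.shape, out.shape)  # (1, 5, 8, 4) (1, 5, 2, 4) (1, 5, 8, 4)
```

With 8 query heads and 2 KV heads, each K/V head is reused by 4 query heads, so the fused output is only (8 + 2 * 2) * head_dim wide rather than 3 * 8 * head_dim.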
- config
Configuration object for the model.
- Type
Phi3Config
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
- sliding_window
Sliding window size (in tokens) for local attention.
- Type
int
- qkv_proj
Fused linear layer for the query, key, and value projections.
- Type
- o_proj
Linear layer for the output projection.
- Type
- attention_performer
Module that performs the core attention computation.
- rotary
Rotary position embedding module with partial RoPE support.
- Type
RoPE
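The sliding_window attribute restricts each query to a band of recent keys. Below is a hedged JAX sketch of such a mask, assuming the window counts the query position itself (i - j < window). Implementations differ on this off-by-one detail, so check the EasyDeL source before relying on it. sliding_window_causal_mask is a hypothetical helper.

```python
import jax.numpy as jnp


def sliding_window_causal_mask(seq_len: int, window: int) -> jnp.ndarray:
    """Boolean (seq_len, seq_len) mask; True means query i may attend to key j.

    Assumed convention: causal (j <= i) and at most `window` positions,
    counting the query itself (i - j < window).
    """
    i = jnp.arange(seq_len)[:, None]
    j = jnp.arange(seq_len)[None, :]
    return (j <= i) & (i - j < window)


mask = sliding_window_causal_mask(seq_len=5, window=3)
print(mask.astype(jnp.int32))  # row 4 allows keys 2, 3 and 4 only
```

Compared with a plain causal mask, keys older than the window are dropped, which bounds the per-token attention cost for long sequences.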
- define_network(config: Phi3Config, dtype: dtype, param_dtype: dtype, precision: Union[None, str, Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], DotAlgorithm, DotAlgorithmPreset], rngs: Rngs)
Overrides the base-class method to create a single fused QKV projection instead of separate Q, K, and V projections.
- Parameters
config – Model configuration
dtype – Data type for computations
param_dtype – Data type for parameters
precision – JAX precision setting
rngs – Random number generators
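- projection_mapping: ClassVar[dict[str, str]] = {'output_projection': 'o_proj', 'query_key_value_projection': 'qkv_proj'}
Maps generic projection roles to this module's attribute names (qkv_proj for the fused query/key/value projection, o_proj for the output projection).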
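Fusing Q, K, and V into qkv_proj changes the parameter layout, not the math. This minimal sketch uses raw JAX arrays rather than EasyDeL's linear layers, with illustrative sizes, and checks that one matmul against column-concatenated weights reproduces three separate projections.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes (assumed): 4 query heads and 2 KV heads of width 4.
hidden_size, q_dim, kv_dim = 16, 16, 8
k_x, k_q, k_k, k_v = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k_x, (2, 3, hidden_size))
w_q = jax.random.normal(k_q, (hidden_size, q_dim))
w_k = jax.random.normal(k_k, (hidden_size, kv_dim))
w_v = jax.random.normal(k_v, (hidden_size, kv_dim))

# Separate projections: three matmuls that each read x.
q_sep, k_sep, v_sep = x @ w_q, x @ w_k, x @ w_v

# Fused projection: concatenate the weights column-wise, do one matmul, then split.
w_qkv = jnp.concatenate([w_q, w_k, w_v], axis=-1)  # (hidden_size, q_dim + 2 * kv_dim)
q_fused, k_fused, v_fused = jnp.split(x @ w_qkv, [q_dim, q_dim + kv_dim], axis=-1)

print(w_qkv.shape)  # (16, 32)
```

The fused form replaces three matmuls with one larger one, which usually maps better onto accelerators. Whether EasyDeL stores its weights in exactly this concatenation order is an assumption here.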
- class easydel.modules.phi3.modeling_phi3.Phi3DecoderLayer(*args: Any, **kwargs: Any)
Bases: Module
Phi-3 Transformer decoder layer.
This module represents a single decoder layer in the Phi-3 model, combining self-attention and MLP sub-layers with residual connections and RMS normalization.
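The residual structure can be sketched as a pre-norm block. This assumes the arrangement of the Hugging Face Phi-3 reference, including an input_layernorm name that does not appear in the attribute list below. The attention and MLP are replaced by hypothetical stand-in functions, and dropout is omitted.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-5):
    """RMS normalization over the last (hidden) axis."""
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * weight


def decoder_layer(x, params, attn_fn, mlp_fn):
    """Pre-norm residual block; residual dropout is omitted (inference behaviour)."""
    # Attention sub-layer: normalize, attend, add back onto the residual stream.
    h = x + attn_fn(rms_norm(x, params["input_layernorm"]))
    # MLP sub-layer: post_attention_layernorm normalizes the MLP input, then a second residual add.
    return h + mlp_fn(rms_norm(h, params["post_attention_layernorm"]))


hidden = 8
params = {
    "input_layernorm": jnp.ones(hidden),  # assumed name, as in the HF reference
    "post_attention_layernorm": jnp.ones(hidden),
}
x = jax.random.normal(jax.random.PRNGKey(0), (2, 4, hidden))
attn_fn = lambda t: 0.5 * t  # hypothetical stand-in for Phi3Attention
mlp_fn = jnp.tanh            # hypothetical stand-in for Phi3MLP
y = decoder_layer(x, params, attn_fn, mlp_fn)
print(y.shape)  # (2, 4, 8)
```

Because each sub-layer only adds to the residual stream, a layer whose attention and MLP output zeros returns its input unchanged.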
- config
Configuration object for the model.
- Type
Phi3Config
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
- self_attn
The self-attention module.
- Type
Phi3Attention
- post_attention_layernorm
RMS normalization applied after the attention layer and before the MLP layer.
- Type
- dropout
Dropout layer applied to the residual connections.
- Type
nn.Dropout
- class easydel.modules.phi3.modeling_phi3.Phi3ForCausalLM(*args: Any, **kwargs: Any)
Bases: BaseCausalLMModule[Phi3Model, Phi3Config]
Phi-3 model with a causal language modeling head.
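How BaseCausalLMModule implements its head and loss is not documented here. As a hedged sketch, a causal LM head projects the final hidden states to vocabulary logits, and training commonly uses next-token cross-entropy with the labels shifted by one position. lm_head_logits and next_token_loss are hypothetical helpers, and the random hidden states stand in for Phi3Model output.

```python
import jax
import jax.numpy as jnp


def lm_head_logits(hidden_states, lm_head_weight):
    """Project final hidden states (batch, seq, hidden) to vocabulary logits (batch, seq, vocab)."""
    return hidden_states @ lm_head_weight


def next_token_loss(logits, input_ids):
    """Mean cross-entropy where the logits at position t are scored against token t + 1."""
    shifted_logits = logits[:, :-1, :]
    targets = input_ids[:, 1:]
    log_probs = jax.nn.log_softmax(shifted_logits, axis=-1)
    target_log_probs = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    return -target_log_probs.mean()


batch, seq, hidden, vocab = 2, 5, 6, 10
k_h, k_w, k_ids = jax.random.split(jax.random.PRNGKey(0), 3)
hidden_states = jax.random.normal(k_h, (batch, seq, hidden))  # stand-in for Phi3Model output
lm_head_weight = jax.random.normal(k_w, (hidden, vocab)) * 0.02
input_ids = jax.random.randint(k_ids, (batch, seq), 0, vocab)

logits = lm_head_logits(hidden_states, lm_head_weight)
loss = next_token_loss(logits, input_ids)
```

The last position has no next token, so it is dropped from the loss; with all-zero logits the loss reduces to log(vocab).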
- class easydel.modules.phi3.modeling_phi3.Phi3MLP(*args: Any, **kwargs: Any)
Bases: Module
Phi-3 MLP module.
This module implements the feed-forward network (MLP) used in the Phi-3 model. A single combined projection (gate_up_proj) produces the gate and up activations; SiLU is applied to the gate, the result is multiplied elementwise with the up activation, and down_proj maps the product back to the hidden size.
- config
Configuration object for the model.
- Type
Phi3Config
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- gate_up_proj
Combined linear layer for the gate and up projections.
- Type
- down_proj
Linear layer for the down projection.
- Type
- activation_fn
Activation function (SiLU).
- Type
callable
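The data flow through gate_up_proj, activation_fn, and down_proj can be sketched with raw weight matrices. This assumes the combined projection is laid out as [gate | up] on the last axis (as in the Hugging Face Phi-3 reference); gated_mlp is a hypothetical name and the sizes are illustrative.

```python
import jax
import jax.numpy as jnp


def gated_mlp(x, w_gate_up, w_down):
    """Gated SiLU MLP with one combined gate/up projection.

    Assumes the combined output is laid out as [gate | up] on the last axis.
    """
    gate, up = jnp.split(x @ w_gate_up, 2, axis=-1)
    return (up * jax.nn.silu(gate)) @ w_down


hidden, intermediate = 8, 16  # illustrative sizes (assumed)
k_x, k_gu, k_d = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k_x, (2, 3, hidden))
w_gate_up = jax.random.normal(k_gu, (hidden, 2 * intermediate))
w_down = jax.random.normal(k_d, (intermediate, hidden))
y = gated_mlp(x, w_gate_up, w_down)
print(y.shape)  # (2, 3, 8)
```

Computing gate and up with one matmul and splitting afterwards gives the same result as two separate projections, in the same spirit as the fused QKV layer.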
- class easydel.modules.phi3.modeling_phi3.Phi3Model(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
The base Phi-3 transformer model.
This class represents the core transformer architecture of the Phi-3 model: a token embedding layer, a stack of Phi3DecoderLayer modules, and a final RMS normalization layer.
- config
Configuration object for the model.
- Type
Phi3Config
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
- embed_tokens
Embedding layer for input tokens.
- Type
nn.Embed
- embed_dropout
Dropout layer applied after embeddings.
- Type
nn.Dropout
- layers
List of decoder layers.
- Type
tp.List[Phi3DecoderLayer]
- gradient_checkpointing
Gradient checkpointing configuration.
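The layout of Phi3Model (embedding, decoder stack, final norm) and the idea behind gradient checkpointing can be sketched together. This sketch wraps each stand-in layer in jax.checkpoint when remat=True; it does not model EasyDeL's gradient_checkpointing policies, omits embed_dropout, and uses hypothetical helper names (model_forward, make_layer).

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-5):
    """RMS normalization over the last axis."""
    return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps) * weight


def model_forward(input_ids, embedding, layer_fns, final_norm_weight, remat=False):
    """Embedding lookup -> stack of layers -> final RMS norm (dropout omitted).

    With remat=True each layer is wrapped in jax.checkpoint, so its internal
    activations are recomputed in the backward pass instead of being stored.
    """
    hidden = embedding[input_ids]  # (batch, seq, hidden)
    for layer in layer_fns:
        hidden = (jax.checkpoint(layer) if remat else layer)(hidden)
    return rms_norm(hidden, final_norm_weight)


def make_layer(scale):
    # Hypothetical stand-in for a Phi3DecoderLayer: a simple residual update.
    return lambda h: h + scale * jnp.tanh(h)


vocab, hidden = 12, 6
embedding = jax.random.normal(jax.random.PRNGKey(0), (vocab, hidden))
layers = [make_layer(s) for s in (0.5, 0.25, 0.125)]
norm_weight = jnp.ones(hidden)
input_ids = jnp.array([[1, 4, 7], [3, 3, 11]])

out_plain = model_forward(input_ids, embedding, layers, norm_weight)
out_remat = model_forward(input_ids, embedding, layers, norm_weight, remat=True)
grads_plain = jax.grad(lambda e: model_forward(input_ids, e, layers, norm_weight).sum())(embedding)
grads_remat = jax.grad(lambda e: model_forward(input_ids, e, layers, norm_weight, remat=True).sum())(embedding)
```

Checkpointing trades memory for compute: activations inside each wrapped layer are recomputed during differentiation, so forward outputs and gradients match the non-checkpointed version.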
- property frequencies
Returns the frequency components (e.g., for RoPE) derived from the configuration.
They are computed with self.config.get_basic_frequencies() on first access; the result is cached and reused on later accesses.
- Returns
The frequency components.
- Return type
jnp.ndarray
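The exact array returned by get_basic_frequencies() is not described here, so the following is a hedged sketch of standard RoPE tables with partial rotation, matching the "partial RoPE support" of the rotary attribute. It assumes base 10000, the rotate-half layout of the Hugging Face reference, and a hypothetical partial_rotary_factor that rotates only the first rotary_dim features of each head. rope_tables, rotate_half, and apply_partial_rope are hypothetical helpers.

```python
import jax
import jax.numpy as jnp


def rope_tables(seq_len, rotary_dim, base=10000.0):
    """cos/sin tables of shape (seq_len, rotary_dim) in the rotate-half layout."""
    inv_freq = 1.0 / (base ** (jnp.arange(0, rotary_dim, 2) / rotary_dim))
    angles = jnp.arange(seq_len)[:, None] * inv_freq[None, :]  # (seq_len, rotary_dim // 2)
    angles = jnp.concatenate([angles, angles], axis=-1)         # (seq_len, rotary_dim)
    return jnp.cos(angles), jnp.sin(angles)


def rotate_half(x):
    """Map (x1, x2) halves to (-x2, x1) along the last axis."""
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate([-x2, x1], axis=-1)


def apply_partial_rope(x, cos, sin):
    """Rotate the first rotary_dim features of every head and pass the rest through.

    x: (batch, seq, heads, head_dim); cos, sin: (seq, rotary_dim).
    """
    rotary_dim = cos.shape[-1]
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    cos, sin = cos[None, :, None, :], sin[None, :, None, :]
    x_rot = x_rot * cos + rotate_half(x_rot) * sin
    return jnp.concatenate([x_rot, x_pass], axis=-1)


head_dim, partial_rotary_factor = 8, 0.5  # hypothetical values
rotary_dim = int(head_dim * partial_rotary_factor)
cos, sin = rope_tables(seq_len=6, rotary_dim=rotary_dim)
x = jax.random.normal(jax.random.PRNGKey(0), (1, 6, 2, head_dim))
y = apply_partial_rope(x, cos, sin)
```

Each rotated pair of features keeps its norm, position 0 is left unchanged (angle 0), and the features beyond rotary_dim are passed through untouched.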