easydel.modules.phi3.modeling_phi3#

class easydel.modules.phi3.modeling_phi3.Phi3Attention(*args: Any, **kwargs: Any)[source]#

Bases: UnifiedAttention

Phi3 Attention module with fused QKV projection.

This module implements the multi-head attention mechanism used in the Phi-3 model. It supports Grouped Query Attention (GQA), Rotary Position Embeddings (RoPE), and sliding window attention. The query, key, and value projections are combined into a single fused linear layer for efficiency.
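
To illustrate the fused projection described above, the following minimal sketch projects hidden states with one weight matrix and then splits the result into query, key, and value heads under GQA. The function name `split_fused_qkv`, the `[query | key | value]` layout along the last axis, and the toy sizes are assumptions made for illustration; they are not taken from the EasyDeL source.

```python
import jax
import jax.numpy as jnp


def split_fused_qkv(qkv, num_heads, num_kv_heads, head_dim):
    """Split a fused QKV activation into per-head query, key, and value tensors."""
    batch, seq_len, _ = qkv.shape
    q_size = num_heads * head_dim
    kv_size = num_kv_heads * head_dim
    # Assumed layout along the last axis: [query | key | value].
    q, k, v = jnp.split(qkv, [q_size, q_size + kv_size], axis=-1)
    q = q.reshape(batch, seq_len, num_heads, head_dim)
    k = k.reshape(batch, seq_len, num_kv_heads, head_dim)
    v = v.reshape(batch, seq_len, num_kv_heads, head_dim)
    return q, k, v


# Toy sizes (hypothetical): 8 query heads share 2 key/value heads.
hidden_size, num_heads, num_kv_heads, head_dim = 64, 8, 2, 8
key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (1, 5, hidden_size))
w_qkv = jax.random.normal(
    key_w, (hidden_size, (num_heads + 2 * num_kv_heads) * head_dim)
)

qkv = x @ w_qkv  # a single matmul produces Q, K, and V together
q, k, v = split_fused_qkv(qkv, num_heads, num_kv_heads, head_dim)
```

With GQA the key and value tensors carry fewer heads than the query tensor, so the fused output width is `(num_heads + 2 * num_kv_heads) * head_dim` rather than `3 * hidden_size`.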

config#

Configuration object for the model.

Type

Phi3Config

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

sliding_window#

Sliding window size for local attention.

Type

int
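
As a rough sketch of what a sliding window means for the attention mask, the example below builds a causal mask in which each query position sees only itself and the most recent preceding tokens. The helper name `sliding_window_causal_mask` and the convention that the window size counts the current token are assumptions; the module may define the window boundary differently.

```python
import jax.numpy as jnp


def sliding_window_causal_mask(seq_len, window):
    """Boolean mask of shape (seq_len, seq_len); True means the key is visible."""
    q_pos = jnp.arange(seq_len)[:, None]
    k_pos = jnp.arange(seq_len)[None, :]
    causal = k_pos <= q_pos               # never attend to future tokens
    in_window = (q_pos - k_pos) < window  # assumed: window includes the current token
    return causal & in_window


mask = sliding_window_causal_mask(seq_len=5, window=3)
```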

qkv_proj#

Fused linear layer for query, key, and value projections.

Type

ColumnParallelLinear

o_proj#

Linear layer for the output projection.

Type

RowParallelLinear

attention_performer#

Module to perform the core attention computation.

Type

FlexibleAttentionModule

rotary#

Rotary position embedding module with partial RoPE support.

Type

RoPE
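
Partial RoPE rotates only a leading slice of each head's features and passes the rest through unchanged. The sketch below shows one common formulation (the "rotate half" variant); the function name `apply_partial_rope`, the `rotary_dim` parameter, and the base of 10000 are illustrative assumptions, not the signature of the `RoPE` module.

```python
import jax.numpy as jnp


def apply_partial_rope(x, positions, rotary_dim, base=10000.0):
    """Rotate the first `rotary_dim` features of x; leave the remainder untouched.

    x: (seq_len, num_heads, head_dim), positions: (seq_len,)
    """
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    inv_freq = 1.0 / (base ** (jnp.arange(0, rotary_dim, 2) / rotary_dim))
    angles = positions[:, None] * inv_freq[None, :]  # (seq_len, rotary_dim // 2)
    cos = jnp.cos(angles)[:, None, :]                # broadcast over heads
    sin = jnp.sin(angles)[:, None, :]
    x1, x2 = jnp.split(x_rot, 2, axis=-1)
    rotated = jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    return jnp.concatenate([rotated, x_pass], axis=-1)


x = jnp.arange(64, dtype=jnp.float32).reshape(4, 2, 8) / 64.0
positions = jnp.arange(4)
out = apply_partial_rope(x, positions, rotary_dim=4)
```

Because each feature pair is rotated rather than scaled, the vector norm of every head is preserved, and position 0 is left unchanged.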

define_network(config: Phi3Config, dtype: dtype, param_dtype: dtype, precision: Union[None, str, Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], DotAlgorithm, DotAlgorithmPreset], rngs: Rngs)[source]#

Overrides the base implementation to create a single fused QKV projection instead of separate Q, K, and V projections.

Parameters
  • config – Model configuration

  • dtype – Data type for computations

  • param_dtype – Data type for parameters

  • precision – JAX precision setting

  • rngs – Random number generators

projection_mapping: ClassVar[dict[str, str]] = {'output_projection': 'o_proj', 'query_key_value_projection': 'qkv_proj'}#

class easydel.modules.phi3.modeling_phi3.Phi3DecoderLayer(*args: Any, **kwargs: Any)[source]#

Bases: Module

Phi3 Transformer Decoder Layer.

This module represents a single decoder layer in the Phi-3 model, combining self-attention and MLP sub-layers with residual connections and RMS normalization.
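
The data flow of such a layer can be sketched as a pre-norm residual block: normalize, run attention, add the residual, then normalize again, run the MLP, and add the second residual. In this sketch `rms_norm`, `decoder_layer`, and the callables `attn_fn` and `mlp_fn` are hypothetical stand-ins for the real sub-modules, and dropout is omitted (as it would be at inference time).

```python
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-5):
    """Scale x by the reciprocal of its root mean square over the last axis."""
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x / jnp.sqrt(variance + eps) * weight


def decoder_layer(x, attn_fn, mlp_fn, attn_norm_weight, mlp_norm_weight):
    """Pre-norm residual block: attention sub-layer followed by MLP sub-layer."""
    h = x + attn_fn(rms_norm(x, attn_norm_weight))     # input_layernorm -> self_attn
    return h + mlp_fn(rms_norm(h, mlp_norm_weight))    # post_attention_layernorm -> mlp


# Toy stand-ins for the attention and MLP sub-modules.
x = jnp.ones((1, 3, 4)) * 2.0
w = jnp.ones((4,))
out = decoder_layer(
    x,
    attn_fn=lambda h: 0.5 * h,
    mlp_fn=lambda h: 0.5 * h,
    attn_norm_weight=w,
    mlp_norm_weight=w,
)
```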

config#

Configuration object for the model.

Type

Phi3Config

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

input_layernorm#

RMS normalization applied before the attention layer.

Type

RMSNorm

self_attn#

The self-attention module.

Type

Phi3Attention

mlp#

The feed-forward (MLP) module.

Type

Phi3MLP

post_attention_layernorm#

RMS normalization applied after the attention layer and before the MLP layer.

Type

RMSNorm

dropout#

Dropout layer applied to the residual connections.

Type

nn.Dropout

class easydel.modules.phi3.modeling_phi3.Phi3ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: BaseCausalLMModule[Phi3Model, Phi3Config]

Phi-3 model with a Causal Language Modeling head.
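
Conceptually, a causal language modeling head maps the final hidden states to one logit per vocabulary entry, and the logits at the last position score the next token. The sketch below uses a hypothetical `lm_head_logits` helper and a toy weight matrix; it does not reflect how `Phi3ForCausalLM` stores or shards its head.

```python
import jax.numpy as jnp


def lm_head_logits(hidden_states, lm_head_weight):
    """Project (batch, seq_len, hidden) states to (batch, seq_len, vocab) logits."""
    return hidden_states @ lm_head_weight


# Toy example: hidden size 4, vocabulary size 6.
hidden_states = jnp.array([[[1.0, 0.0, 0.0, 0.0],
                            [0.0, 0.0, 2.0, 0.0]]])
lm_head_weight = jnp.eye(4, 6)

logits = lm_head_logits(hidden_states, lm_head_weight)
# Greedy choice of the next token from the last position's logits.
next_token = jnp.argmax(logits[:, -1, :], axis=-1)
```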

get_decoder()[source]#

Returns the decoder part of the model’s graph definition.

get_embedding()[source]#

Returns the embedding layer of the module.

get_encoder()[source]#

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()[source]#

Returns the language model head of the module.

class easydel.modules.phi3.modeling_phi3.Phi3MLP(*args: Any, **kwargs: Any)[source]#

Bases: Module

Phi3 MLP module.

This module implements the feed-forward network (MLP) used in the Phi-3 model. It consists of a combined gate and up projection, SiLU activation, and a down projection.
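
The computation described above can be sketched in a few lines: one matmul produces the gate and up activations together, the gate half passes through SiLU and multiplies the up half, and the result is projected back down. The function name `gated_mlp`, the weight names, and the gate-then-up ordering of the fused output are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def gated_mlp(x, w_gate_up, w_down):
    """Fused gate/up projection, SiLU gating, then down projection."""
    # Assumed layout along the last axis: [gate | up].
    gate, up = jnp.split(x @ w_gate_up, 2, axis=-1)
    return (jax.nn.silu(gate) * up) @ w_down


# Toy sizes (hypothetical): hidden size 8, intermediate size 16.
k_x, k_g, k_u, k_d = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k_x, (2, 3, 8))
w_gate = jax.random.normal(k_g, (8, 16))
w_up = jax.random.normal(k_u, (8, 16))
w_down = jax.random.normal(k_d, (16, 8))

# Concatenating the two weights gives the fused projection.
w_gate_up = jnp.concatenate([w_gate, w_up], axis=-1)
out = gated_mlp(x, w_gate_up, w_down)
```

Fusing the two projections replaces two matmuls over the same input with one wider matmul, which yields the same values as computing the gate and up projections separately.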

config#

Configuration object for the model.

Type

Phi3Config

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

gate_up_proj#

Combined linear layer for gate and up projections.

Type

ParallelLinear

down_proj#

Linear layer for the down projection.

Type

ParallelLinear

activation_fn#

Activation function (SiLU).

Type

callable

class easydel.modules.phi3.modeling_phi3.Phi3Model(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

The base Phi-3 model transformer.

This class represents the core transformer architecture of the Phi-3 model, consisting of an embedding layer, multiple Phi3DecoderLayer layers, and a final RMS normalization layer.

config#

Configuration object for the model.

Type

Phi3Config

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

embed_tokens#

Embedding layer for input tokens.

Type

nn.Embed

embed_dropout#

Dropout layer applied after embeddings.

Type

nn.Dropout

layers#

List of decoder layers.

Type

tp.List[Phi3DecoderLayer]

norm#

Final RMS normalization applied after the last decoder layer.

Type

RMSNorm

gradient_checkpointing#

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers

property frequencies#

Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_frequencies() and caches the result.

Returns

The frequency components, potentially cached.

Return type

jnp.ndarray
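
The retrieve-or-compute behavior can be illustrated with a cached property that delegates to the configuration once and reuses the result afterwards. `ToyConfig` and `ToyModel` below are hypothetical stand-ins, and the table computed by this toy `get_basic_frequencies` (position times inverse frequency) is only an assumption about what such a method might return; the real method may produce a different structure.

```python
from functools import cached_property

import jax.numpy as jnp


class ToyConfig:
    """Hypothetical stand-in for a model configuration."""

    def __init__(self, head_dim, max_position_embeddings, rope_theta=10000.0):
        self.head_dim = head_dim
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta

    def get_basic_frequencies(self):
        # Assumed content: one rotation angle per (position, feature pair).
        exponents = jnp.arange(0, self.head_dim, 2) / self.head_dim
        inv_freq = 1.0 / (self.rope_theta ** exponents)
        positions = jnp.arange(self.max_position_embeddings)
        return jnp.outer(positions, inv_freq)


class ToyModel:
    """Hypothetical stand-in showing the caching pattern only."""

    def __init__(self, config):
        self.config = config

    @cached_property
    def frequencies(self):
        # Computed on first access, then stored on the instance.
        return self.config.get_basic_frequencies()


model = ToyModel(ToyConfig(head_dim=8, max_position_embeddings=16))
freqs = model.frequencies
```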

get_decoder()[source]#

Returns the decoder part of the model’s graph definition.

get_embedding()[source]#

Returns the embedding layer of the module.

get_encoder()[source]#

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()[source]#

Returns the language model head of the module. Base models don’t have a language model head.