easydel.modules.phi3.__init__

class easydel.modules.phi3.__init__.Phi3Config(vocab_size=32064, hidden_size=3072, intermediate_size=8192, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=None, resid_pdrop=0.0, embd_pdrop=0.0, attention_dropout=0.0, hidden_act='silu', max_position_embeddings=4096, original_max_position_embeddings=4096, initializer_range=0.02, rms_norm_eps=1e-05, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, bos_token_id=1, eos_token_id=32000, pad_token_id=32000, sliding_window=None, bits: Optional[int] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)

Bases: EasyDeLBaseConfig

Configuration class for Phi-3 models; it holds the hyperparameters used by Phi3Model and Phi3ForCausalLM. Configuration objects inherit from EasyDeLBaseConfig and can be used to control the model outputs. See the EasyDeLBaseConfig documentation for more information.

Parameters
  • vocab_size (int, optional, defaults to 32064) – Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the input_ids passed to the model.

  • hidden_size (int, optional, defaults to 3072) – Dimensionality of the token embeddings and of the hidden states in the decoder layers.

  • intermediate_size (int, optional, defaults to 8192) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer decoder.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer decoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer decoder.

  • num_key_value_heads (int, optional) – Number of key and value heads for each attention layer in the Transformer decoder. Setting it below num_attention_heads enables grouped-query attention; if not set, it defaults to num_attention_heads (standard multi-head attention).

  • resid_pdrop (float, optional, defaults to 0.0) – The dropout probability applied to the outputs of the attention and MLP blocks before they are added back to the residual stream.

  • embd_pdrop (float, optional, defaults to 0.0) – The dropout ratio for the embeddings.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) used in the decoder's feed-forward layers. If a string, it is looked up by name, e.g. “silu”, “gelu”, “relu”, “swish”, or “gelu_new”.

  • max_position_embeddings (int, optional, defaults to 4096) – The maximum sequence length that this model might ever be used with. Set it at least as large as the longest sequence you expect to process (e.g., 2048 or 4096).

  • original_max_position_embeddings (int, optional, defaults to 4096) – The maximum sequence length the model was originally trained with. RoPE scaling (see rope_scaling) uses it as the reference length when the context is extended beyond it.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the RMS normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether the model should return the last key/value attention states so they can be reused as a cache during autoregressive decoding.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 10000.0) – The theta value to use for rotary position embeddings.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for RoPE scaling, used to extend the usable context length. If None, no scaling is applied.

  • bos_token_id (int, optional, defaults to 1) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 32000) – The id of the end-of-sequence token.

  • pad_token_id (int, optional, defaults to 32000) – The index of the padding token in the vocabulary.

  • sliding_window (int, optional) – The size of the sliding attention window. If None, no sliding window is applied.

  • bits (int, optional) – The number of bits to quantize the model to.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing strategy.
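As a rough illustration of how the attention-related fields interact, the sketch below derives per-head sizes from them. It assumes the usual convention head_dim = hidden_size / num_attention_heads and that num_attention_heads must be divisible by num_key_value_heads; resolve_attention_dims is a hypothetical helper, not part of EasyDeL.

```python
from typing import Optional


def resolve_attention_dims(
    hidden_size: int = 3072,
    num_attention_heads: int = 32,
    num_key_value_heads: Optional[int] = None,
) -> dict:
    """Hypothetical helper: compute head_dim and the grouped-query group size."""
    # Documented behavior: num_key_value_heads falls back to num_attention_heads.
    if num_key_value_heads is None:
        num_key_value_heads = num_attention_heads
    # Assumed constraints (standard for multi-head / grouped-query attention).
    if hidden_size % num_attention_heads != 0:
        raise ValueError("hidden_size must be divisible by num_attention_heads")
    if num_attention_heads % num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    return {
        "head_dim": hidden_size // num_attention_heads,
        "num_key_value_heads": num_key_value_heads,
        # Number of query heads sharing each key/value head (1 means plain MHA).
        "queries_per_kv_head": num_attention_heads // num_key_value_heads,
    }


print(resolve_attention_dims())
print(resolve_attention_dims(num_key_value_heads=8)["queries_per_kv_head"])
```

With the defaults, each of the 32 heads is 96-dimensional and every query head has its own key/value head; num_key_value_heads=8 would instead let groups of 4 query heads share one key/value head.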

get_partition_rules(*args, **kwargs)

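Get the partition rules for the model: pairs of a parameter-name pattern and the PartitionSpec that describes how matching parameters are sharded.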

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]
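The return type suggests each rule pairs a string pattern with a PartitionSpec. The sketch below shows one plausible way such rules could be applied to parameter paths, assuming first-match-wins with re.search on "/"-joined names. The rule patterns, mesh-axis names ("fsdp", "tp"), parameter paths, and the match_spec helper are all hypothetical, and a plain tuple stands in for jax.sharding.PartitionSpec so the sketch runs without JAX.

```python
import re
from typing import Optional, Sequence, Tuple

# Stand-in for jax.sharding.PartitionSpec: one mesh-axis name (or None) per
# array dimension; an empty tuple means "fully replicated".
Spec = Tuple[Optional[str], ...]

# Hypothetical rules in the documented (pattern, spec) shape. Order matters:
# the first pattern found in a parameter path wins.
RULES: Sequence[Tuple[str, Spec]] = (
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"self_attn/qkv_proj/kernel", ("fsdp", "tp")),
    (r"mlp/gate_up_proj/kernel", ("fsdp", "tp")),
    (r"lm_head/kernel", ("fsdp", "tp")),
    (r".*", ()),  # catch-all: replicate anything not matched above
)


def match_spec(param_path: str, rules: Sequence[Tuple[str, Spec]] = RULES) -> Spec:
    """Hypothetical helper: return the spec of the first rule matching param_path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


print(match_spec("model/layers/0/self_attn/qkv_proj/kernel"))  # ('fsdp', 'tp')
print(match_spec("model/norm/kernel"))  # ()
```

A trailing catch-all rule keeps every parameter covered; without it, an unmatched name would raise instead of silently being left unsharded.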

property granted_freq_max_position_embedding: int

Returns the maximum position embedding size specifically for frequency-based position embeddings.

If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.

Returns

The granted maximum position embedding size for frequency encoding.

Return type

int

property granted_mask_max_position_embedding: int

Returns the maximum position embedding size specifically for mask-based position embeddings.

If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.

Returns

The granted maximum position embedding size for mask encoding.

Return type

int

model_type: str = 'phi3'

class easydel.modules.phi3.__init__.Phi3ForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Phi-3 model with a Causal Language Modeling head.

This model consists of the base Phi-3 transformer (Phi3Model) followed by a linear layer (lm_head) that projects the transformer's output hidden states to the vocabulary size, producing logits for next-token prediction. When tie_word_embeddings is True, the input token embeddings are shared with this output projection.
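The projection performed by lm_head, and the effect of tie_word_embeddings, can be sketched in plain Python for a single hidden state. This illustrates the math only, not EasyDeL's JAX implementation; lm_head_logits and greedy_next_token are hypothetical helpers, and weights are stored as vocab_size rows of hidden_size values.

```python
from typing import List, Optional

Matrix = List[List[float]]  # vocab_size rows, each of length hidden_size


def lm_head_logits(
    hidden: List[float],
    embedding: Matrix,
    lm_head_weight: Optional[Matrix] = None,
    tie_word_embeddings: bool = False,
) -> List[float]:
    """Hypothetical helper: project one hidden state to vocab_size logits."""
    # With tying, the embedding table doubles as the output projection;
    # otherwise a separate lm_head weight is required.
    weight = embedding if tie_word_embeddings else lm_head_weight
    if weight is None:
        raise ValueError("lm_head_weight is required when embeddings are not tied")
    # One dot product per vocabulary entry.
    return [sum(w * h for w, h in zip(row, hidden)) for row in weight]


def greedy_next_token(logits: List[float]) -> int:
    """Return the id of the highest-scoring token."""
    return max(range(len(logits)), key=logits.__getitem__)


embedding = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # toy vocab of 3, hidden size 2
hidden = [0.25, 0.5]
tied = lm_head_logits(hidden, embedding, tie_word_embeddings=True)
print(tied, greedy_next_token(tied))  # [0.25, 0.5, 0.75] 2
untied = lm_head_logits(hidden, embedding, lm_head_weight=[[0.0, 2.0], [1.0, 0.0], [0.0, 0.0]])
print(untied, greedy_next_token(untied))  # [1.0, 0.25, 0.0] 0
```

Tying removes one vocab_size × hidden_size matrix from the parameter count; with the default sizes that is 32064 × 3072 = 98,500,608 parameters.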

config

Configuration object for the model.

Type

Phi3Config

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

model

The core Phi-3 transformer model.

Type

Phi3Model

lm_head

The linear layer for projecting hidden states to vocabulary logits.

Type

ParallelLinear

class easydel.modules.phi3.__init__.Phi3Model(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

The base Phi-3 model transformer.

This class represents the core transformer architecture of the Phi-3 model: an embedding layer, a stack of num_hidden_layers Phi3DecoderLayer blocks, and a final RMS normalization layer.
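The data flow described above (embedding lookup, decoder stack, final RMSNorm) can be sketched as follows. The RMSNorm here is the standard formulation y = x / sqrt(mean(x^2) + eps) * weight, with eps playing the role of rms_norm_eps; the decoder layers are passed in as opaque callables, and phi3_model_forward is a hypothetical helper rather than EasyDeL's implementation.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]


def rms_norm(x: Vector, weight: Vector, eps: float = 1e-5) -> Vector:
    """Scale x by 1 / sqrt(mean(x^2) + eps), then elementwise by a learned weight."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def phi3_model_forward(
    token_ids: Sequence[int],
    embedding: List[Vector],
    layers: Sequence[Callable[[List[Vector]], List[Vector]]],
    norm_weight: Vector,
    eps: float = 1e-5,
) -> List[Vector]:
    """Hypothetical sketch: embed -> decoder layers -> final RMSNorm per token."""
    hidden = [list(embedding[t]) for t in token_ids]  # embedding lookup
    # embed_dropout would be applied here during training (omitted).
    for layer in layers:  # each layer maps the whole sequence to a new sequence
        hidden = layer(hidden)
    return [rms_norm(h, norm_weight, eps) for h in hidden]


def double(seq: List[Vector]) -> List[Vector]:
    """Toy stand-in for a decoder layer: scales every hidden state by 2."""
    return [[2.0 * v for v in h] for h in seq]


out = phi3_model_forward([0], [[2.0, 2.0], [1.0, 3.0]], layers=[double],
                         norm_weight=[1.0, 1.0], eps=0.0)
print(out)  # [[1.0, 1.0]]
```

Because RMSNorm divides by the root mean square, scaling a hidden state by a constant leaves the normalized output unchanged when eps is 0, which the toy doubling layer illustrates.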

config

Configuration object for the model.

Type

Phi3Config

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

embed_tokens

Embedding layer for input tokens.

Type

nn.Embed

embed_dropout

Dropout layer applied after embeddings.

Type

nn.Dropout

layers

List of decoder layers.

Type

tp.List[Phi3DecoderLayer]

norm

Final RMS normalization layer.

Type

RMSNorm

gradient_checkpointing

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers

property frequencies

Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_frequencies() and caches the result.

Returns

The frequency components, potentially cached.

Return type

jnp.ndarray
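The exact output of get_basic_frequencies() is not documented here. As an illustration only, the sketch below computes standard, unscaled RoPE tables and caches them with functools.cached_property: channel pair i gets inverse frequency theta^(-2i/d), and position p is rotated by the angle p * theta^(-2i/d). It assumes head_dim = hidden_size / num_attention_heads (96 by default), uses rope_theta as theta, ignores rope_scaling, and RotaryFrequencies is a hypothetical class, not EasyDeL's.

```python
import math
from functools import cached_property
from typing import List, Tuple


class RotaryFrequencies:
    """Hypothetical stand-in: computes and caches unscaled RoPE cos/sin tables."""

    def __init__(self, head_dim: int = 96, rope_theta: float = 10000.0,
                 max_positions: int = 4096) -> None:
        self.head_dim = head_dim
        self.rope_theta = rope_theta
        self.max_positions = max_positions

    @cached_property
    def frequencies(self) -> Tuple[List[List[float]], List[List[float]]]:
        # Inverse frequency per channel pair: theta^(-2i/d) for i = 0 .. d/2 - 1.
        inv_freq = [self.rope_theta ** (-2 * i / self.head_dim)
                    for i in range(self.head_dim // 2)]
        # Rotation angle for position p and channel pair i is p * inv_freq[i].
        cos = [[math.cos(p * f) for f in inv_freq] for p in range(self.max_positions)]
        sin = [[math.sin(p * f) for f in inv_freq] for p in range(self.max_positions)]
        return cos, sin


rope = RotaryFrequencies(head_dim=4, max_positions=3)
cos_table, sin_table = rope.frequencies
print(len(cos_table), len(cos_table[0]))  # 3 2
print(rope.frequencies is rope.frequencies)  # True: computed once, then cached
```

Caching matters because the tables depend only on the configuration, so recomputing them on every forward pass would be wasted work.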