easydel.modules.mosaic_mpt.modeling_mosaic_flax#
- class easydel.modules.mosaic_mpt.modeling_mosaic_flax.MptAttention(*args: Any, **kwargs: Any)[source]#
Bases: AttentionModule
MPT Attention module.
This module implements the multi-head attention mechanism used in the MPT model. It supports ALiBi positional bias and allows for different attention implementations.
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
Dimensionality of the hidden states.
- Type
int
- Wqkv#
Combined linear layer for query, key, and value projections.
- Type
- out_proj#
Linear layer for the output projection.
- Type
- dropout#
Dropout layer applied after the output projection.
- Type
nn.Dropout
- n_heads#
Number of attention heads.
- Type
int
- max_seq_length#
Maximum sequence length supported.
- Type
int
- head_dim#
Dimensionality of each attention head.
- Type
int
- softmax_scale#
Scale factor for the softmax function.
- Type
float
- attention_performer#
Module to perform the core attention computation.
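The attributes above (a fused Wqkv projection, n_heads, head_dim, softmax_scale, and an out_proj) suggest the following data flow. This is a minimal sketch in plain JAX, not the EasyDeL implementation: the function name fused_qkv_attention, the raw weight-matrix arguments, and the default scale of head_dim ** -0.5 are assumptions made for illustration, and dropout is omitted.

```python
import jax
import jax.numpy as jnp


def fused_qkv_attention(x, w_qkv, w_out, n_heads, bias=None, softmax_scale=None):
    """Hypothetical helper: causal multi-head attention with a fused QKV projection.

    x:     (batch, seq, hidden)
    w_qkv: (hidden, 3 * hidden)  stands in for the Wqkv layer
    w_out: (hidden, hidden)      stands in for the out_proj layer
    bias:  optional additive bias broadcastable to (batch, n_heads, seq, seq),
           for example an ALiBi tensor
    """
    batch, seq, hidden = x.shape
    head_dim = hidden // n_heads
    if softmax_scale is None:
        softmax_scale = head_dim ** -0.5  # assumed default

    # One matmul produces queries, keys, and values together.
    q, k, v = jnp.split(x @ w_qkv, 3, axis=-1)

    def split_heads(t):
        # (batch, seq, hidden) -> (batch, n_heads, seq, head_dim)
        return t.reshape(batch, seq, n_heads, head_dim).transpose(0, 2, 1, 3)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)

    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) * softmax_scale
    if bias is not None:
        scores = scores + bias  # positional bias is added before the softmax

    # Causal mask: a query may only attend to keys at the same or earlier positions.
    causal = jnp.tril(jnp.ones((seq, seq), dtype=bool))
    scores = jnp.where(causal, scores, jnp.finfo(scores.dtype).min)

    probs = jax.nn.softmax(scores, axis=-1)
    context = jnp.einsum("bhqk,bhkd->bhqd", probs, v)

    # Merge heads back and apply the output projection.
    context = context.transpose(0, 2, 1, 3).reshape(batch, seq, hidden)
    return context @ w_out
```

In the module itself the core computation is delegated to attention_performer, so the masking and softmax steps shown here may be carried out differently depending on the attention implementation selected.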
- class easydel.modules.mosaic_mpt.modeling_mosaic_flax.MptBlock(*args: Any, **kwargs: Any)[source]#
Bases: Module
MPT Transformer block.
This module represents a single transformer block in the MPT model, containing self-attention and MLP sub-layers with residual connections and layer normalization. It utilizes ALiBi for positional information.
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- norm_1#
Layer normalization before the attention layer.
- Type
nn.LayerNorm
- attn#
The self-attention module.
- Type
- norm_2#
Layer normalization before the MLP layer.
- Type
nn.LayerNorm
- resid_attn_dropout#
Dropout applied after the attention layer’s residual connection.
- Type
nn.Dropout
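Since norm_1 and norm_2 are described as coming before the attention and MLP layers, the block appears to follow a pre-normalization residual pattern. The sketch below illustrates that pattern under stated assumptions: layer_norm and pre_norm_block are hypothetical helpers, the normalization has no learned scale or bias, and dropout is left out.

```python
import jax.numpy as jnp


def layer_norm(x, eps=1e-5):
    # Simplified layer normalization over the last axis (no learned parameters).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def pre_norm_block(x, attn_fn, mlp_fn):
    """Hypothetical pre-norm transformer block.

    attn_fn and mlp_fn are callables mapping (batch, seq, hidden) to the same shape.
    """
    # Attention sub-layer: normalize, attend, then add back the residual.
    x = x + attn_fn(layer_norm(x))
    # MLP sub-layer: normalize, transform, then add back the residual.
    x = x + mlp_fn(layer_norm(x))
    return x
```

Because each sub-layer only adds to its input, a sub-layer that returns zeros leaves the hidden states unchanged, which is a convenient sanity check when wiring up a block.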
- class easydel.modules.mosaic_mpt.modeling_mosaic_flax.MptForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases: EasyDeLBaseModule
MPT model with a language modeling head.
This model extends the base MptModel by adding a linear layer (lm_head) on top to predict the next token in a sequence, making it suitable for causal language modeling tasks.
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- lm_head#
The language modeling head. If use_lm_head in the config is True (tying embeddings), this will be None.
- Type
ParallelLinear, optional
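The lm_head description implies two ways of producing logits: through a separate linear layer, or by reusing the token embedding matrix when embeddings are tied. The sketch below illustrates both paths with raw arrays. The function lm_logits and its argument names are hypothetical and do not reflect how EasyDeL selects between the two.

```python
import jax.numpy as jnp


def lm_logits(hidden, embedding, lm_head=None):
    """Hypothetical helper: project hidden states to vocabulary logits.

    hidden:    (batch, seq, hidden_size)
    embedding: (vocab_size, hidden_size)  token embedding matrix
    lm_head:   optional (hidden_size, vocab_size) weight; None means tied embeddings
    """
    if lm_head is None:
        # Tied embeddings: reuse the transposed embedding matrix as the output projection.
        return hidden @ embedding.T
    return hidden @ lm_head


def greedy_next_token(logits):
    # Pick the highest-scoring token at the last position of each sequence.
    return jnp.argmax(logits[:, -1, :], axis=-1)
```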
- class easydel.modules.mosaic_mpt.modeling_mosaic_flax.MptMLP(*args: Any, **kwargs: Any)[source]#
Bases: Module
MPT MLP module.
This module implements the feed-forward network (MLP) used in the MPT model. It consists of an up-projection, GELU activation, and a down-projection, followed by dropout.
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- up_proj#
Linear layer for up-projection.
- Type
- down_proj#
Linear layer for down-projection.
- Type
Dropout layer applied to the output.
- Type
nn.Dropout
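The up-projection, GELU, down-projection sequence described for this module can be sketched in a few lines of JAX. This is an illustration only: mlp_sketch is a hypothetical name, biases and dropout are omitted, and the 4x expansion used in the usage note is an assumption rather than a value taken from the MPT configuration.

```python
import jax
import jax.numpy as jnp


def mlp_sketch(x, w_up, w_down):
    """Hypothetical feed-forward sketch.

    x:      (batch, seq, hidden)
    w_up:   (hidden, intermediate)
    w_down: (intermediate, hidden)
    """
    h = x @ w_up                            # up-projection
    h = jax.nn.gelu(h, approximate=False)   # exact (erf-based) GELU
    return h @ w_down                       # down-projection
```

For a hidden size of 8 and an assumed 4x expansion, w_up would have shape (8, 32) and w_down shape (32, 8), so the output has the same shape as the input.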
- class easydel.modules.mosaic_mpt.modeling_mosaic_flax.MptModel(*args: Any, **kwargs: Any)[source]#
Bases: EasyDeLBaseModule
MPT model implementation.
This class implements the main MPT transformer model architecture, consisting of an embedding layer (token and optional positional), multiple MptBlock layers, and a final layer normalization.
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- wte#
Token embedding layer.
- Type
nn.Embed
- emb_drop#
Dropout layer applied after embeddings.
- Type
nn.Dropout
- norm_f#
Final layer normalization.
- Type
nn.LayerNorm
- alibi#
Precomputed ALiBi tensor if using ALiBi.
- Type
chex.Array, optional
- property alibi#
- easydel.modules.mosaic_mpt.modeling_mosaic_flax.build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max=8)[source]#
Builds the ALiBi tensor for MPT models.
ALiBi (Attention with Linear Biases) is a method to incorporate positional information into transformer models without explicit position embeddings. It adds a bias to the attention scores based on the distance between query and key positions.
- Parameters
num_heads (int) – The number of attention heads.
sequence_length (int) – The length of the sequence.
alibi_bias_max (int, optional) – The maximum bias value allowed by ALiBi. Defaults to 8.
- Returns
The ALiBi tensor of shape (1, num_heads, sequence_length, sequence_length).
- Return type
chex.Array
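To make the idea of a distance-based bias concrete, the sketch below builds a tensor with the documented shape (1, num_heads, sequence_length, sequence_length). It is not the body of build_mpt_alibi_tensor: the name alibi_bias_sketch is hypothetical, the slope formula assumes num_heads is a power of two, and the library function may handle other head counts or lay out the bias differently.

```python
import jax.numpy as jnp


def alibi_bias_sketch(num_heads, sequence_length, alibi_bias_max=8):
    """Hypothetical ALiBi bias; assumes num_heads is a power of two."""
    # One slope per head: 2 ** (-alibi_bias_max * (h + 1) / num_heads) for h = 0..num_heads-1.
    exponents = jnp.arange(1, num_heads + 1, dtype=jnp.float32) * (alibi_bias_max / num_heads)
    slopes = 1.0 / jnp.power(2.0, exponents)  # (num_heads,)

    positions = jnp.arange(sequence_length)
    # distance[i, j] = j - i: zero on the diagonal, negative for keys before the query.
    distance = (positions[None, :] - positions[:, None]).astype(jnp.float32)
    # Future keys (j > i) are clamped to zero here; a causal mask is expected to hide them.
    distance = jnp.minimum(distance, 0.0)

    bias = slopes[:, None, None] * distance[None, :, :]  # (num_heads, seq, seq)
    return bias[None, ...]  # (1, num_heads, seq, seq)
```

The resulting tensor is added to the attention scores before the softmax, so keys farther behind the query receive a larger penalty, and heads with larger slopes attend more locally.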