easydel.modules.mosaic_mpt.modeling_mosaic#
- class easydel.modules.mosaic_mpt.modeling_mosaic.MptAttention(*args: Any, **kwargs: Any)[source]#
Bases:
UnifiedAttention
MPT attention module with ALiBi positional bias.
Inherits from UnifiedAttention. Uses fused QKV projection and ALiBi (Attention with Linear Biases) for positional information. Overrides forward_alibi to handle custom ALiBi bias computation with masking.
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- Wqkv#
Fused linear layer for query, key, and value projections.
- Type
- out_proj#
Linear layer for the output projection.
- Type
- resid_dropout#
Dropout layer applied after the output projection.
- Type
nn.Dropout
- define_network(config: MptConfig, dtype: dtype, param_dtype: dtype, precision: Union[None, str, Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], DotAlgorithm, DotAlgorithmPreset], rngs: Rngs)[source]#
Define MPT-specific network with fused QKV projection.
- forward_alibi(hidden_states: Float[Array, 'batch seq_len hidden_dim'], mask_info: ejkernel.types.mask.MaskInfo | None, position_ids: Int[Array, 'batch seq_len'], mode: Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], cache_view: easydel.layers.caching.transformer.cache.TransformerCacheView | easydel.layers.caching.ragged_page.cache.RaggedPagesCacheView | None = None, cache_metadata: easydel.layers.caching.transformer.cache.TransformerMetadata | easydel.layers.caching.ragged_page.cache.RaggedPagesMetadata | None = None, output_attentions: bool = False, alibi: jaxtyping.Float[Array, 'batch_or_1 heads qseq_len_or_1 kvseq_len_or_1'] | None = None) AttentionLayerOutput[source]#
Override ALiBi forward with MPT’s custom bias computation and masking.
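The fused QKV projection described for this class can be illustrated with a minimal pure-Python sketch. It assumes the fused output is laid out as query, key, and value concatenated along the last axis, each of width `hidden_dim`. That layout and the helper names `linear` and `split_fused_qkv` are hypothetical, chosen for illustration; they are not part of EasyDeL's API.

```python
def linear(x, weight):
    """Apply a bias-free linear map: x has length d_in, weight is d_in x d_out."""
    d_out = len(weight[0])
    return [sum(x[i] * weight[i][j] for i in range(len(x))) for j in range(d_out)]


def split_fused_qkv(fused, hidden_dim):
    """Split one fused projection vector into query, key and value parts.

    Assumption: the fused vector is laid out as [q | k | v].
    """
    if len(fused) != 3 * hidden_dim:
        raise ValueError("expected a vector of length 3 * hidden_dim")
    q = fused[:hidden_dim]
    k = fused[hidden_dim:2 * hidden_dim]
    v = fused[2 * hidden_dim:]
    return q, k, v


# Toy example: hidden_dim = 2, so the fused weight is 2 x 6.
w_qkv = [
    [1.0, 0.0, 2.0, 0.0, 3.0, 0.0],
    [0.0, 1.0, 0.0, 2.0, 0.0, 3.0],
]
q, k, v = split_fused_qkv(linear([1.0, 1.0], w_qkv), hidden_dim=2)
```

A single matrix multiply produces all three projections, which is the usual motivation for fusing them into one layer such as `Wqkv`.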
- class easydel.modules.mosaic_mpt.modeling_mosaic.MptBlock(*args: Any, **kwargs: Any)[source]#
Bases:
Module
MPT Transformer block.
This module represents a single transformer block in the MPT model, containing self-attention and MLP sub-layers with residual connections and layer normalization. It utilizes ALiBi for positional information.
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- norm_1#
Layer normalization before the attention layer.
- Type
nn.LayerNorm
- attn#
The self-attention module.
- Type
- norm_2#
Layer normalization before the MLP layer.
- Type
nn.LayerNorm
- resid_attn_dropout#
Dropout applied after the attention layer’s residual connection.
- Type
nn.Dropout
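The block structure described above (normalization before each sub-layer, followed by a residual connection) can be sketched in pure Python. This is an illustration under stated assumptions, not the EasyDeL implementation: `layer_norm` omits the learnable scale and bias, dropout is treated as the identity (inference mode), and `attn_fn` and `mlp_fn` are hypothetical stand-ins for the attention and MLP sub-layers.

```python
import math


def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learnable parameters)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def block_forward(x, attn_fn, mlp_fn):
    """Pre-norm residual block: x + attn(norm_1(x)), then h + mlp(norm_2(h))."""
    h = [a + b for a, b in zip(x, attn_fn(layer_norm(x)))]
    return [a + b for a, b in zip(h, mlp_fn(layer_norm(h)))]


# With sub-layers that output zeros, the residual path returns the input unchanged.
zeros = lambda h: [0.0] * len(h)
out = block_forward([1.0, 2.0, 3.0], zeros, zeros)
```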
- class easydel.modules.mosaic_mpt.modeling_mosaic.MptForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
BaseCausalLMModule[MptModel, MptConfig]
MPT model with a language modeling head.
- class easydel.modules.mosaic_mpt.modeling_mosaic.MptMLP(*args: Any, **kwargs: Any)[source]#
Bases:
Module
MPT MLP module.
This module implements the feed-forward network (MLP) used in the MPT model. It consists of an up-projection, GELU activation, and a down-projection, followed by dropout.
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- up_proj#
Linear layer for up-projection.
- Type
- down_proj#
Linear layer for down-projection.
- Type
Dropout layer applied to the output.
- Type
nn.Dropout
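The feed-forward path described above (up-projection, GELU, down-projection) can be sketched in pure Python. The sketch assumes bias-free projections and the exact erf-based GELU, which may differ from the variant the module actually uses. It also skips dropout, as at inference time. The names `matvec` and `mlp_forward` and the toy weights are hypothetical.

```python
import math


def gelu(x):
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def matvec(x, weight):
    """Multiply a vector of length d_in by a d_in x d_out weight matrix."""
    d_out = len(weight[0])
    return [sum(x[i] * weight[i][j] for i in range(len(x))) for j in range(d_out)]


def mlp_forward(x, w_up, w_down):
    """Up-project, apply GELU element-wise, then down-project."""
    hidden = [gelu(h) for h in matvec(x, w_up)]
    return matvec(hidden, w_down)


# Toy weights: hidden size 2, expanded to 4 and projected back to 2.
w_up = [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]
w_down = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
y = mlp_forward([1.0, -1.0], w_up, w_down)
```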
- class easydel.modules.mosaic_mpt.modeling_mosaic.MptModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
MPT model implementation.
This class implements the main MPT transformer model architecture, consisting of an embedding layer (token and optional positional), multiple MptBlock layers, and a final layer normalization.
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- wte#
Token embedding layer.
- Type
nn.Embed
- emb_drop#
Dropout layer applied after embeddings.
- Type
nn.Dropout
- norm_f#
Final layer normalization.
- Type
nn.LayerNorm
- alibi#
Precomputed ALiBi tensor if using ALiBi.
- Type
chex.Array, optional
- property alibi#
- easydel.modules.mosaic_mpt.modeling_mosaic.build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max=8)[source]#
Builds the ALiBi tensor for MPT models.
ALiBi (Attention with Linear Biases) is a method to incorporate positional information into transformer models without explicit position embeddings. It adds a bias to the attention scores based on the distance between query and key positions.
- Parameters
num_heads (int) – The number of attention heads.
sequence_length (int) – The length of the sequence.
alibi_bias_max (int, optional) – The maximum bias value allowed by ALiBi. Defaults to 8.
- Returns
The ALiBi tensor of shape (1, num_heads, sequence_length, sequence_length).
- Return type
chex.Array
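The distance-based bias described above can be sketched in pure Python. This is a generic ALiBi illustration, not the body of `build_mpt_alibi_tensor`. It assumes per-head slopes of the form 2^(-alibi_bias_max * (h + 1) / num_heads), a power-of-two head count, and a symmetric penalty `-slope * |q - k|` (under a causal mask only the entries with `k <= q` matter). The names `alibi_slopes` and `alibi_bias_sketch` are hypothetical.

```python
def alibi_slopes(num_heads, alibi_bias_max=8):
    """Per-head slopes, geometrically spaced between 2**(-max/num_heads) and 2**(-max).

    Assumption: num_heads is a power of two; other head counts need extra
    handling that this sketch leaves out.
    """
    if num_heads <= 0 or num_heads & (num_heads - 1):
        raise ValueError("this sketch only handles power-of-two head counts")
    return [2.0 ** (-alibi_bias_max * (h + 1) / num_heads) for h in range(num_heads)]


def alibi_bias_sketch(num_heads, sequence_length, alibi_bias_max=8):
    """Nested list shaped (1, num_heads, sequence_length, sequence_length).

    Entry [0][h][q][k] is -slope_h * |q - k|: the farther a key is from the
    query, the more its attention score is penalized.
    """
    slopes = alibi_slopes(num_heads, alibi_bias_max)
    positions = range(sequence_length)
    return [[
        [[-slope * abs(q - k) for k in positions] for q in positions]
        for slope in slopes
    ]]


bias = alibi_bias_sketch(num_heads=8, sequence_length=4)
```

The bias is added to the attention scores before the softmax, so heads with larger slopes attend more locally than heads with smaller ones.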