easydel.modules.mosaic_mpt.modeling_mosaic_flax#

class easydel.modules.mosaic_mpt.modeling_mosaic_flax.MptAttention(*args: Any, **kwargs: Any)[source]#

Bases: AttentionModule

MPT Attention module.

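This module implements the multi-head attention mechanism used in the MPT model. Query, key, and value are produced by a single combined projection (Wqkv), ALiBi positional bias is supported, and the core attention computation is delegated to attention_performer so that different attention implementations can be used.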

config#

Configuration object for the model.

Type

MptConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

hidden_size#

Dimensionality of the hidden states.

Type

int

Wqkv#

Combined linear layer for query, key, and value projections.

Type

ParallelLinear

out_proj#

Linear layer for the output projection.

Type

ParallelLinear

dropout#

Dropout layer applied after the output projection.

Type

nn.Dropout

n_heads#

Number of attention heads.

Type

int

max_seq_length#

Maximum sequence length supported.

Type

int

head_dim#

Dimensionality of each attention head.

Type

int

softmax_scale#

Scale factor for the softmax function.

Type

float

attention_performer#

Module to perform the core attention computation.

Type

FlexibleAttentionModule
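
The attributes above (Wqkv, out_proj, n_heads, head_dim, softmax_scale) describe a fused-projection attention layer. The following is a minimal sketch of that computation for one sequence, not the EasyDeL implementation: it uses NumPy as a stand-in for jax.numpy, and the function name fused_qkv_attention, the plain weight matrices, the query/key/value split order, the default scale of 1/sqrt(head_dim), and the shape of the optional alibi argument are all assumptions made for illustration.

```python
import numpy as np


def fused_qkv_attention(x, w_qkv, w_out, n_heads, alibi=None, softmax_scale=None):
    """Causal multi-head attention for one sequence (hypothetical helper)."""
    seq_len, hidden = x.shape
    head_dim = hidden // n_heads
    if softmax_scale is None:
        softmax_scale = 1.0 / np.sqrt(head_dim)  # assumed default

    # One matmul yields query, key and value; split the last axis in three.
    q, k, v = np.split(x @ w_qkv, 3, axis=-1)

    def to_heads(t):
        # (seq, hidden) -> (heads, seq, head_dim)
        return t.reshape(seq_len, n_heads, head_dim).transpose(1, 0, 2)

    q, k, v = to_heads(q), to_heads(k), to_heads(v)

    scores = (q @ k.transpose(0, 2, 1)) * softmax_scale  # (heads, seq, seq)
    if alibi is not None:
        scores = scores + alibi  # any shape broadcastable to (heads, seq, seq)

    # Causal mask: a query may only attend to itself and earlier positions.
    causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    scores = np.where(causal, scores, -np.inf)

    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)

    # Merge heads back to (seq, hidden) and apply the output projection.
    merged = (weights @ v).transpose(1, 0, 2).reshape(seq_len, hidden)
    return merged @ w_out, weights


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 8))
w_qkv = rng.normal(size=(8, 24)) * 0.1
w_out = rng.normal(size=(8, 8)) * 0.1
out, weights = fused_qkv_attention(x, w_qkv, w_out, n_heads=2)
```

Dropout and sharding are left out. In the real module the attention step itself is handled by attention_performer.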

class easydel.modules.mosaic_mpt.modeling_mosaic_flax.MptBlock(*args: Any, **kwargs: Any)[source]#

Bases: Module

MPT Transformer block.

This module represents a single transformer block in the MPT model. It contains a self-attention sub-layer and an MLP sub-layer, each preceded by layer normalization and wrapped in a residual connection. Positional information comes from ALiBi.

config#

Configuration object for the model.

Type

MptConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

norm_1#

Layer normalization before the attention layer.

Type

nn.LayerNorm

attn#

The self-attention module.

Type

MptAttention

norm_2#

Layer normalization before the MLP layer.

Type

nn.LayerNorm

ffn#

The feed-forward (MLP) module.

Type

MptMLP

resid_attn_dropout#

Dropout applied after the attention layer’s residual connection.

Type

nn.Dropout
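
Because norm_1 and norm_2 are applied before their sub-layers, the block follows a pre-normalization residual layout. The sketch below shows that data flow only. It uses NumPy as a stand-in for jax.numpy. The names layer_norm and pre_norm_block are hypothetical, the sub-layers are passed in as plain callables, and dropout and the learnable LayerNorm parameters are omitted.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """LayerNorm over the last axis, without learnable scale or bias."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def pre_norm_block(x, attn_fn, ffn_fn):
    """Pre-norm residual block: normalize, apply the sub-layer, add the input back."""
    x = x + attn_fn(layer_norm(x))  # norm_1 -> attn -> residual add
    x = x + ffn_fn(layer_norm(x))   # norm_2 -> ffn  -> residual add
    return x


rng = np.random.default_rng(0)
x = rng.normal(size=(3, 4))

# With sub-layers that output zeros, the residual path returns the input unchanged.
zeros = lambda h: np.zeros_like(h)
y = pre_norm_block(x, zeros, zeros)
```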

class easydel.modules.mosaic_mpt.modeling_mosaic_flax.MptForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

MPT model with a language modeling head.

This model extends the base MptModel by adding a linear layer (lm_head) on top to predict the next token in a sequence, making it suitable for causal language modeling tasks.

config#

Configuration object for the model.

Type

MptConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

transformer#

The core MPT transformer model.

Type

MptModel

lm_head#

The language modeling head. It is None when the output projection is tied to the token embeddings; the use_lm_head setting in the config controls which of the two is used.

Type

ParallelLinear, optional
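
The two ways of producing logits, a separate lm_head or the tied token embedding matrix, can be sketched as follows. This is an illustration under assumptions, not the library code: NumPy stands in for jax.numpy, compute_logits is a hypothetical helper, and the weights are plain arrays with the embedding stored as (vocab_size, hidden_size).

```python
import numpy as np


def compute_logits(hidden_states, embedding, lm_head=None):
    """Project hidden states to vocabulary logits (hypothetical helper)."""
    if lm_head is None:
        # Tied weights: reuse the embedding table, transposed to (hidden, vocab).
        return hidden_states @ embedding.T
    # Separate head: a dedicated (hidden, vocab) projection matrix.
    return hidden_states @ lm_head


rng = np.random.default_rng(0)
hidden_states = rng.normal(size=(3, 4))   # (seq_len, hidden_size)
embedding = rng.normal(size=(10, 4))      # (vocab_size, hidden_size)
lm_head = rng.normal(size=(4, 10))        # (hidden_size, vocab_size)

tied_logits = compute_logits(hidden_states, embedding)
head_logits = compute_logits(hidden_states, embedding, lm_head)

# Greedy next-token choice from the last position.
next_token = int(np.argmax(tied_logits[-1]))
```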

class easydel.modules.mosaic_mpt.modeling_mosaic_flax.MptMLP(*args: Any, **kwargs: Any)[source]#

Bases: Module

MPT MLP module.

This module implements the feed-forward network (MLP) used in the MPT model. It consists of an up-projection, GELU activation, and a down-projection, followed by dropout.

config#

Configuration object for the model.

Type

MptConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

up_proj#

Linear layer for up-projection.

Type

ParallelLinear

down_proj#

Linear layer for down-projection.

Type

ParallelLinear

hidden_dropout#

Dropout layer applied to the output.

Type

nn.Dropout
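
The up-projection, GELU, down-projection sequence can be sketched as below. NumPy stands in for jax.numpy. The helper names, the bias-free weight matrices, the 4x expansion width, and the choice of the exact erf-based GELU (rather than the tanh approximation) are assumptions for illustration. Dropout is omitted, as it would be at inference time.

```python
import math

import numpy as np


def gelu(x):
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    erf = np.vectorize(math.erf)
    return 0.5 * x * (1.0 + erf(x / math.sqrt(2.0)))


def mlp(x, w_up, w_down):
    """Feed-forward sub-layer: up-projection, GELU, down-projection."""
    return gelu(x @ w_up) @ w_down


rng = np.random.default_rng(0)
hidden_size, expansion = 4, 4                # expansion factor is an assumption
x = rng.normal(size=(2, hidden_size))
w_up = rng.normal(size=(hidden_size, expansion * hidden_size)) * 0.1
w_down = rng.normal(size=(expansion * hidden_size, hidden_size)) * 0.1

y = mlp(x, w_up, w_down)                     # same shape as x
```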

class easydel.modules.mosaic_mpt.modeling_mosaic_flax.MptModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

MPT model implementation.

This class implements the main MPT transformer model architecture, consisting of an embedding layer (token and optional positional), multiple MptBlock layers, and a final layer normalization.

config#

Configuration object for the model.

Type

MptConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

wte#

Token embedding layer.

Type

nn.Embed

emb_drop#

Dropout layer applied after embeddings.

Type

nn.Dropout

blocks#

List of transformer blocks.

Type

tp.List[MptBlock]

norm_f#

Final layer normalization.

Type

nn.LayerNorm

alibi#

Precomputed ALiBi tensor if using ALiBi.

Type

chex.Array, optional
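
The overall data flow through wte, blocks, and norm_f can be summarized with the short sketch below. It is a schematic only: mpt_forward is a hypothetical name, the embedding is a plain NumPy lookup table, and the blocks and final normalization are arbitrary callables standing in for MptBlock and nn.LayerNorm. Embedding dropout and the ALiBi bias passed to each block are omitted.

```python
import numpy as np


def mpt_forward(input_ids, wte, blocks, norm_f):
    """Embed tokens, run every block in order, then apply the final norm."""
    hidden = wte[input_ids]        # (seq_len,) ids -> (seq_len, hidden_size)
    for block in blocks:
        hidden = block(hidden)
    return norm_f(hidden)


# Toy stand-ins so the flow can be traced by hand.
wte = np.arange(12, dtype=np.float64).reshape(4, 3)   # vocab of 4, hidden size 3
blocks = [lambda h: h + 1.0, lambda h: h * 2.0]
out = mpt_forward(np.array([0, 2]), wte, blocks, norm_f=lambda h: h)
```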

property alibi#

easydel.modules.mosaic_mpt.modeling_mosaic_flax.build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max=8)[source]#

Builds the ALiBi tensor for MPT models.

ALiBi (Attention with Linear Biases) is a method to incorporate positional information into transformer models without explicit position embeddings. It adds a bias to the attention scores that grows more negative linearly with the distance between query and key positions, using a different slope for each attention head.

Parameters
  • num_heads (int) – The number of attention heads.

  • sequence_length (int) – The length of the sequence.

  • alibi_bias_max (int, optional) – The maximum bias value allowed by ALiBi. Defaults to 8.

Returns

The ALiBi tensor of shape (1, num_heads, sequence_length, sequence_length).

Return type

chex.Array
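
As a rough illustration of the idea, the sketch below builds an ALiBi bias of the documented shape with NumPy standing in for jax.numpy. It is not the library function: alibi_slopes and alibi_bias_sketch are hypothetical names. The slope formula (a geometric sequence whose smallest slope is 2 to the power of minus alibi_bias_max) and the interleaving used when num_heads is not a power of two are assumptions modeled on common MPT implementations. The bias here is written as the negative absolute query-key distance times the slope.

```python
import math

import numpy as np


def alibi_slopes(num_heads, alibi_bias_max=8):
    """Per-head slopes 2**(-k * alibi_bias_max / n) for k = 1..n (assumed scheme)."""
    n = 2 ** math.ceil(math.log2(num_heads))   # next power of two
    exponents = np.arange(1, n + 1, dtype=np.float64) * (alibi_bias_max / n)
    slopes = 1.0 / np.power(2.0, exponents)
    if n != num_heads:
        # Assumed handling for non-power-of-two head counts: interleave, then trim.
        slopes = np.concatenate([slopes[1::2], slopes[::2]])[:num_heads]
    return slopes


def alibi_bias_sketch(num_heads, sequence_length, alibi_bias_max=8):
    """Bias of shape (1, num_heads, sequence_length, sequence_length)."""
    positions = np.arange(sequence_length)
    # distance[i, j] = -|j - i|: zero on the diagonal, more negative farther away.
    distance = -np.abs(positions[None, :] - positions[:, None]).astype(np.float64)
    slopes = alibi_slopes(num_heads, alibi_bias_max)
    return slopes[None, :, None, None] * distance[None, None, :, :]


bias = alibi_bias_sketch(num_heads=8, sequence_length=4)
```

The bias is added to the attention scores before the softmax, so heads with larger slopes concentrate on nearby tokens while heads with smaller slopes can attend farther back.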