easydel.modules.mosaic_mpt.modeling_mosaic#

class easydel.modules.mosaic_mpt.modeling_mosaic.MptAttention(*args: Any, **kwargs: Any)[source]#

Bases: UnifiedAttention

MPT Attention module with ALiBi positional bias.

Uses a fused QKV projection and ALiBi (Attention with Linear Biases) for positional information. Overrides forward_alibi to apply MPT’s custom ALiBi bias computation together with masking.

config#

Configuration object for the model.

Type

MptConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

Wqkv#

Fused linear layer for query, key, and value projections.

Type

ColumnParallelLinear

out_proj#

Linear layer for the output projection.

Type

RowParallelLinear

resid_dropout#

Dropout layer applied after the output projection.

Type

nn.Dropout

define_network(config: MptConfig, dtype: jnp.dtype, param_dtype: jnp.dtype, precision: jax.lax.PrecisionLike, rngs: nn.Rngs)[source]#

Defines the MPT-specific network layers, using a fused QKV projection.
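A fused QKV projection computes queries, keys, and values with one matrix multiplication and then splits the result. The snippet below is a minimal sketch of that idea in plain JAX, not EasyDeL's implementation: the function name fused_qkv, the weight layout ([q | k | v] along the last axis), and the absence of a bias term are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def fused_qkv(hidden_states, w_qkv, num_heads):
    """Project once, then split the result into query, key, and value heads."""
    batch, seq_len, hidden_dim = hidden_states.shape
    head_dim = hidden_dim // num_heads
    # A single matmul produces [q | k | v] concatenated along the last axis
    # (assumed layout, chosen for illustration).
    qkv = hidden_states @ w_qkv  # (batch, seq_len, 3 * hidden_dim)
    q, k, v = jnp.split(qkv, 3, axis=-1)

    def to_heads(t):
        # (batch, seq_len, hidden_dim) -> (batch, seq_len, num_heads, head_dim)
        return t.reshape(batch, seq_len, num_heads, head_dim)

    return to_heads(q), to_heads(k), to_heads(v)


x = jax.random.normal(jax.random.PRNGKey(0), (2, 5, 16))
w_qkv = 0.02 * jax.random.normal(jax.random.PRNGKey(1), (16, 48))
q, k, v = fused_qkv(x, w_qkv, num_heads=4)
print(q.shape, k.shape, v.shape)
```

Compared with three separate projections, the fused form issues one larger matmul, which is usually friendlier to accelerators.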

forward_alibi(hidden_states: Float[Array, 'batch seq_len hidden_dim'], mask_info: ejkernel.types.mask.MaskInfo | None, position_ids: Int[Array, 'batch seq_len'], mode: Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], cache_view: easydel.layers.caching.transformer.cache.TransformerCacheView | easydel.layers.caching.ragged_page.cache.RaggedPagesCacheView | None = None, cache_metadata: easydel.layers.caching.transformer.cache.TransformerMetadata | easydel.layers.caching.ragged_page.cache.RaggedPagesMetadata | None = None, output_attentions: bool = False, alibi: jaxtyping.Float[Array, 'batch_or_1 heads qseq_len_or_1 kvseq_len_or_1'] | None = None) AttentionLayerOutput[source]#

Overrides the ALiBi forward pass with MPT’s custom bias computation and masking.
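To make the interaction between the ALiBi bias and the attention mask concrete, here is a hedged sketch in plain JAX. It is not the body of forward_alibi: the function attention_with_alibi, the tensor layout (batch, heads, seq, head_dim), and the example slopes are hypothetical. It only illustrates the general pattern of adding the bias to the scaled scores and then masking before the softmax.

```python
import jax
import jax.numpy as jnp


def attention_with_alibi(q, k, v, alibi, mask):
    """Scaled dot-product attention with an additive ALiBi bias.

    q, k, v: (batch, heads, seq, head_dim)
    alibi:   broadcastable to (batch, heads, q_seq, kv_seq)
    mask:    boolean, True where attention is allowed
    """
    scale = q.shape[-1] ** -0.5
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) * scale
    scores = scores + alibi  # positional bias enters before the softmax
    # Masked positions get the most negative finite value, so softmax maps them to 0.
    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bhkd->bhqd", weights, v), weights


seq_len, heads, head_dim = 4, 2, 8
keys = jax.random.split(jax.random.PRNGKey(0), 3)
q, k, v = (jax.random.normal(kk, (1, heads, seq_len, head_dim)) for kk in keys)

pos = jnp.arange(seq_len)
rel = jnp.minimum(pos[None, :] - pos[:, None], 0)  # key index minus query index, <= 0
slopes = jnp.array([0.5, 0.25])                    # hypothetical per-head slopes
alibi = (slopes[:, None, None] * rel)[None]        # (1, heads, seq, seq)
causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))[None, None]

out, weights = attention_with_alibi(q, k, v, alibi, causal)
```

With the causal mask, the first query can only attend to the first key, so its attention weights collapse onto that single position regardless of the bias.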

class easydel.modules.mosaic_mpt.modeling_mosaic.MptBlock(*args: Any, **kwargs: Any)[source]#

Bases: Module

MPT Transformer block.

This module represents a single transformer block in the MPT model, containing self-attention and MLP sub-layers with residual connections and layer normalization. It utilizes ALiBi for positional information.
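The data flow of such a block can be sketched as follows. This is an illustrative outline under stated assumptions, not the MptBlock source: it assumes a pre-norm arrangement (normalization before each sub-layer, as the norm_1 and norm_2 descriptions suggest), omits dropout and ALiBi, and uses a parameter-free layer_norm helper and placeholder callables attn_fn and ffn_fn that are hypothetical.

```python
import jax.numpy as jnp


def layer_norm(x, eps=1e-5):
    """Parameter-free layer normalization over the last axis."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def block_forward(x, attn_fn, ffn_fn):
    """Pre-norm transformer block: two sub-layers, each wrapped in a residual."""
    # Attention sub-layer: normalize, attend, add the residual.
    h = x + attn_fn(layer_norm(x))
    # MLP sub-layer: normalize, transform, add the residual.
    return h + ffn_fn(layer_norm(h))


x = jnp.arange(24, dtype=jnp.float32).reshape(1, 3, 8)
# With sub-layers that output zeros, the residual path returns the input unchanged.
unchanged = block_forward(x, jnp.zeros_like, jnp.zeros_like)
# With an identity "attention", the block adds the normalized input to the residual.
shifted = block_forward(x, lambda t: t, jnp.zeros_like)
```

The residual connections mean each sub-layer only has to learn an update to the hidden state rather than a full replacement.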

config#

Configuration object for the model.

Type

MptConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

norm_1#

Layer normalization before the attention layer.

Type

nn.LayerNorm

attn#

The self-attention module.

Type

MptAttention

norm_2#

Layer normalization before the MLP layer.

Type

nn.LayerNorm

ffn#

The feed-forward (MLP) module.

Type

MptMLP

resid_attn_dropout#

Dropout applied after the attention layer’s residual connection.

Type

nn.Dropout

class easydel.modules.mosaic_mpt.modeling_mosaic.MptForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: BaseCausalLMModule[MptModel, MptConfig]

MPT model with a language modeling head.

class easydel.modules.mosaic_mpt.modeling_mosaic.MptMLP(*args: Any, **kwargs: Any)[source]#

Bases: Module

MPT MLP module.

This module implements the feed-forward network (MLP) used in the MPT model. It consists of an up-projection, GELU activation, and a down-projection, followed by dropout.
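The up-projection, GELU, down-projection, and dropout sequence can be sketched in plain JAX as below. This is an assumption-laden illustration rather than the MptMLP code: the function mlp_forward, the bias-free weights, the 4x expansion factor, and the use of exact (non-approximate) GELU are all choices made for the example.

```python
import jax
import jax.numpy as jnp


def mlp_forward(x, w_up, w_down, dropout_rate=0.0, rng=None):
    """Feed-forward network: up-project, GELU, down-project, optional dropout."""
    h = x @ w_up                            # (..., hidden) -> (..., expanded)
    h = jax.nn.gelu(h, approximate=False)   # exact GELU (assumed variant)
    h = h @ w_down                          # (..., expanded) -> (..., hidden)
    if rng is not None and dropout_rate > 0.0:
        # Inverted dropout: zero random elements and rescale the survivors.
        keep = jax.random.bernoulli(rng, 1.0 - dropout_rate, h.shape)
        h = jnp.where(keep, h / (1.0 - dropout_rate), 0.0)
    return h


hidden, expanded = 8, 32  # hypothetical sizes with a 4x expansion
x = jax.random.normal(jax.random.PRNGKey(0), (2, 3, hidden))
w_up = 0.02 * jax.random.normal(jax.random.PRNGKey(1), (hidden, expanded))
w_down = 0.02 * jax.random.normal(jax.random.PRNGKey(2), (expanded, hidden))

y_eval = mlp_forward(x, w_up, w_down)
y_train = mlp_forward(x, w_up, w_down, dropout_rate=0.1, rng=jax.random.PRNGKey(3))
```

Passing no rng disables dropout, which corresponds to evaluation mode in this sketch.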

config#

Configuration object for the model.

Type

MptConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

up_proj#

Linear layer for up-projection.

Type

ParallelLinear

down_proj#

Linear layer for down-projection.

Type

ParallelLinear

hidden_dropout#

Dropout layer applied to the output.

Type

nn.Dropout

class easydel.modules.mosaic_mpt.modeling_mosaic.MptModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

MPT model implementation.

This class implements the main MPT transformer model architecture, consisting of an embedding layer (token and optional positional), multiple MptBlock layers, and a final layer normalization.
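At a high level, the forward pass described here is an embedding lookup followed by the stacked blocks and a final normalization. The sketch below shows only that control flow. It is not the MptModel implementation: model_forward and its arguments are hypothetical, the blocks are stand-in callables, and the optional positional embedding, dropout, ALiBi, and caching are left out.

```python
import jax
import jax.numpy as jnp


def model_forward(input_ids, embedding_table, blocks, final_norm):
    """Embed tokens, run them through each block in order, then normalize."""
    hidden = embedding_table[input_ids]  # (batch, seq_len, hidden_dim)
    for block in blocks:
        hidden = block(hidden)
    return final_norm(hidden)


embedding_table = jax.random.normal(jax.random.PRNGKey(0), (10, 8))
input_ids = jnp.array([[1, 4, 7], [0, 2, 9]])
# Stand-in blocks and norm, used only to make the data flow visible.
blocks = [lambda h: h + 1.0, lambda h: h * 2.0]
hidden = model_forward(input_ids, embedding_table, blocks, final_norm=lambda h: h)
```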

config#

Configuration object for the model.

Type

MptConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

wte#

Token embedding layer.

Type

nn.Embed

emb_drop#

Dropout layer applied after embeddings.

Type

nn.Dropout

blocks#

List of transformer blocks.

Type

tp.List[MptBlock]

norm_f#

Final layer normalization.

Type

nn.LayerNorm

alibi#

Precomputed ALiBi tensor if using ALiBi.

Type

chex.Array, optional

property alibi#

get_decoder()[source]#

Returns the decoder part of the model’s graph definition.

get_embedding()[source]#

Returns the embedding layer of the module.

get_encoder()[source]#

Returns the encoder part of the model’s graph definition. Decoder-only models do not have an encoder.

get_lm_head()[source]#

Returns the language model head of the module. Base models do not have a language model head.

easydel.modules.mosaic_mpt.modeling_mosaic.build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max=8)[source]#

Builds the ALiBi tensor for MPT models.

ALiBi (Attention with Linear Biases) incorporates positional information into transformer models without explicit position embeddings. It adds a per-head bias to the attention scores that is linear in the distance between query and key positions.

Parameters
  • num_heads (int) – The number of attention heads.

  • sequence_length (int) – The length of the sequence.

  • alibi_bias_max (int, optional) – The maximum bias value allowed by ALiBi. Defaults to 8.

Returns

The ALiBi tensor of shape (1, num_heads, sequence_length, sequence_length).

Return type

chex.Array
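As a worked illustration of the idea, the sketch below builds an ALiBi-style bias with the documented output shape. It is not the source of build_mpt_alibi_tensor: the function alibi_bias_sketch is hypothetical, it assumes num_heads is a power of two, and it assumes the slope for head k (counting from 1) is 2^(-alibi_bias_max * k / num_heads), which reduces to the geometric sequence from the ALiBi paper when alibi_bias_max is 8.

```python
import jax.numpy as jnp


def alibi_bias_sketch(num_heads, sequence_length, alibi_bias_max=8):
    """Illustrative ALiBi bias of shape (1, num_heads, seq_len, seq_len)."""
    # Per-head slopes: 2^(-alibi_bias_max * k / num_heads) for k = 1..num_heads
    # (assumed formula; valid as written only for power-of-two head counts).
    exponents = jnp.arange(1, num_heads + 1, dtype=jnp.float32) * (
        alibi_bias_max / num_heads
    )
    slopes = 1.0 / jnp.power(2.0, exponents)  # (num_heads,)

    positions = jnp.arange(sequence_length)
    # distance[i, j] = j - i, clipped at 0 so that future keys get no bonus;
    # a causal mask is expected to remove them anyway.
    distance = jnp.minimum(positions[None, :] - positions[:, None], 0)

    bias = slopes[:, None, None] * distance[None, :, :]  # (heads, seq, seq)
    return bias[None]  # add the leading broadcast dimension


bias = alibi_bias_sketch(num_heads=4, sequence_length=6)
# Head 0 has slope 2**-2 = 0.25, so a key 3 positions back is penalized by 0.75.
print(bias.shape, float(bias[0, 0, 3, 0]))
```

Heads with larger slopes focus more sharply on nearby tokens, while heads with smaller slopes can attend further back.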