easydel.modules.dbrx.modeling_dbrx

class easydel.modules.dbrx.modeling_dbrx.DbrxAttention(*args: Any, **kwargs: Any)

Bases: UnifiedAttention

DBRX Attention module with fused QKV projection.

This module implements the multi-head attention mechanism used in the DBRX model. It supports Grouped Query Attention (GQA) and Rotary Position Embeddings (RoPE). The query, key, and value projections are combined into a single fused linear layer for efficiency, and the module supports optional QKV clipping.

It overrides forward_standard to handle the fused QKV projection efficiently.
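To make the GQA part concrete, here is a minimal NumPy sketch of grouped query attention. It assumes the common layout where each key/value head is shared by `num_heads // num_kv_heads` consecutive query heads. The helper name `gqa_attention` and the shapes are illustrative only and are not part of EasyDeL's API.

```python
import numpy as np


def gqa_attention(q, k, v, causal=True):
    """Grouped query attention (hypothetical helper, not EasyDeL's API).

    q: (seq, num_heads, head_dim); k, v: (seq, num_kv_heads, head_dim).
    """
    seq, num_heads, head_dim = q.shape
    num_kv_heads = k.shape[1]
    assert num_heads % num_kv_heads == 0
    group = num_heads // num_kv_heads
    # Share each KV head across `group` consecutive query heads.
    k = np.repeat(k, group, axis=1)
    v = np.repeat(v, group, axis=1)
    # Scaled dot-product scores: (num_heads, seq_q, seq_k).
    scores = np.einsum("qhd,khd->hqk", q, k) / np.sqrt(head_dim)
    if causal:
        mask = np.tril(np.ones((seq, seq), dtype=bool))
        scores = np.where(mask, scores, -np.inf)
    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.einsum("hqk,khd->qhd", weights, v)


rng = np.random.default_rng(0)
q = rng.standard_normal((5, 8, 16))
k = rng.standard_normal((5, 2, 16))
v = rng.standard_normal((5, 2, 16))
out = gqa_attention(q, k, v)
print(out.shape)  # (5, 8, 16)
```

With 8 query heads and 2 KV heads, query heads 0 to 3 read KV head 0 and heads 4 to 7 read KV head 1. The KV cache therefore only needs to store 2 heads.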

config

Configuration object for the model.

Type

DbrxConfig

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

Wqkv

Fused linear layer for query, key, and value projections.

Type

ColumnParallelLinear

out_proj

Linear layer for the output projection.

Type

RowParallelLinear

attention_performer

Module to perform the core attention computation.

Type

FlexibleAttentionModule

rotary

Rotary position embedding module.

Type

RoPE

resid_dropout

Residual dropout layer.

Type

nn.Dropout

define_network(config: DbrxConfig, dtype: dtype, param_dtype: dtype, precision: Union[None, str, Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], DotAlgorithm, DotAlgorithmPreset], rngs: Rngs)

Overrides the base method to create a single fused QKV projection (Wqkv) instead of separate Q, K, and V projections.

Parameters
  • config – Model configuration

  • dtype – Data type for computations

  • param_dtype – Data type for parameters

  • precision – JAX precision setting

  • rngs – Random number generators
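A single fused layer can replace three separate ones because its weight is the column-wise concatenation of the Q, K, and V weights. Under GQA, this gives an output width of `(num_heads + 2 * num_kv_heads) * head_dim`. The NumPy sketch below demonstrates the equivalence. It assumes bias-free projections stored as `(in_features, out_features)` and a Q, K, V column order, and its weight names are hypothetical.

```python
import numpy as np

hidden, num_heads, num_kv_heads, head_dim = 32, 8, 2, 4
rng = np.random.default_rng(0)

# Separate (hypothetical) projection weights, stored as (in_features, out_features).
w_q = rng.standard_normal((hidden, num_heads * head_dim))
w_k = rng.standard_normal((hidden, num_kv_heads * head_dim))
w_v = rng.standard_normal((hidden, num_kv_heads * head_dim))

# Fused weight: concatenate along the output dimension, in Q, K, V order.
w_qkv = np.concatenate([w_q, w_k, w_v], axis=1)
fused_width = (num_heads + 2 * num_kv_heads) * head_dim
assert w_qkv.shape == (hidden, fused_width)

x = rng.standard_normal((3, hidden))
# One matmul produces the same values as three separate matmuls.
fused = x @ w_qkv
separate = np.concatenate([x @ w_q, x @ w_k, x @ w_v], axis=1)
print(np.allclose(fused, separate))  # True
```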

forward(hidden_states: Float[Array, 'batch seq_len hidden_dim'], mask_info: ejkernel.types.mask.MaskInfo | None, position_ids: Int[Array, 'batch seq_len'], mode: Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], cache_view: easydel.layers.caching.transformer.cache.TransformerCacheView | easydel.layers.caching.ragged_page.cache.RaggedPagesCacheView | None = None, cache_metadata: easydel.layers.caching.transformer.cache.TransformerMetadata | easydel.layers.caching.ragged_page.cache.RaggedPagesMetadata | None = None, output_attentions: bool = False, frequencies: jaxtyping.Float[Array, 'seq_len head_dim'] | None = None, alibi: jaxtyping.Float[Array, 'batch_or_1 heads qseq_len_or_1 kvseq_len_or_1'] | None = None)

Overrides the base method to split the output of the fused QKV projection efficiently, with optional clipping of the QKV activations.
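Here is a NumPy sketch of the split-and-clip step. It assumes the fused activation is laid out as `[q | k | v]` along the last axis and that clipping bounds every value to `[-clip_qkv, clip_qkv]` before the split. `split_qkv` and the `clip_qkv` argument are hypothetical names used for illustration.

```python
import numpy as np


def split_qkv(qkv, num_heads, num_kv_heads, head_dim, clip_qkv=None):
    """Split a fused QKV activation into per-head Q, K, V (illustrative sketch).

    qkv: (batch, seq, (num_heads + 2 * num_kv_heads) * head_dim)
    clip_qkv: optional bound; values are clipped to [-clip_qkv, clip_qkv].
    """
    if clip_qkv is not None:
        qkv = np.clip(qkv, -clip_qkv, clip_qkv)
    q_size = num_heads * head_dim
    kv_size = num_kv_heads * head_dim
    # Split points along the last axis: [q | k | v].
    q, k, v = np.split(qkv, [q_size, q_size + kv_size], axis=-1)
    batch, seq = qkv.shape[:2]
    q = q.reshape(batch, seq, num_heads, head_dim)
    k = k.reshape(batch, seq, num_kv_heads, head_dim)
    v = v.reshape(batch, seq, num_kv_heads, head_dim)
    return q, k, v


rng = np.random.default_rng(0)
qkv = 10.0 * rng.standard_normal((2, 5, (8 + 2 * 2) * 4))
q, k, v = split_qkv(qkv, num_heads=8, num_kv_heads=2, head_dim=4, clip_qkv=8.0)
print(q.shape, k.shape, v.shape)  # (2, 5, 8, 4) (2, 5, 2, 4) (2, 5, 2, 4)
print(float(np.abs(q).max()) <= 8.0)  # True
```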

projection_mapping: ClassVar[dict[str, str]] = {'output_projection': 'out_proj', 'query_key_value_projection': 'Wqkv'}
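`projection_mapping` translates generic projection roles into this module's attribute names. This presumably lets shared UnifiedAttention logic locate `Wqkv` and `out_proj` without hard-coding them. The sketch below shows that lookup pattern under this assumption, using a hypothetical stub class and a hypothetical `get_projection` method.

```python
class FusedAttentionStub:
    """Hypothetical stand-in showing how a projection_mapping could be used."""

    projection_mapping = {
        "output_projection": "out_proj",
        "query_key_value_projection": "Wqkv",
    }

    def __init__(self):
        self.Wqkv = "fused qkv layer"
        self.out_proj = "output layer"

    def get_projection(self, role):
        # Translate a generic role name into this module's attribute name.
        return getattr(self, self.projection_mapping[role])


stub = FusedAttentionStub()
print(stub.get_projection("query_key_value_projection"))  # fused qkv layer
```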

class easydel.modules.dbrx.modeling_dbrx.DbrxBlock(*args: Any, **kwargs: Any)

Bases: Module

Single transformer block for DBRX models.

Combines the attention sub-layer with a mixture-of-experts (MoE) feedforward network, using residual connections and normalization.
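Here is a NumPy sketch of the data flow through one block. It assumes the pre-norm layout of the Hugging Face DBRX port: normalize, attend, add the residual, normalize again, apply the MoE FFN, and add the residual. `attn` and `ffn` are hypothetical callables standing in for the attention and DbrxFFN sub-layers, and the LayerNorm here has no learned parameters.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # Parameter-free LayerNorm over the last axis (scale=1, bias=0).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def block_forward(x, attn, ffn):
    """One DBRX-style block (sketch): norm -> attn -> add -> norm -> ffn -> add."""
    residual = x
    h = attn(layer_norm(x))
    h = h + residual            # first residual connection
    residual = h
    h = ffn(layer_norm(h))      # the FFN (MoE in DBRX) sees normalized input
    return h + residual         # second residual connection


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 8))
# Zero-output stand-ins make both sub-layers no-ops, so the block is the identity.
out = block_forward(x, attn=lambda h: 0.0 * h, ffn=lambda h: 0.0 * h)
print(np.allclose(out, x))  # True
```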

class easydel.modules.dbrx.modeling_dbrx.DbrxExpertGLU(*args: Any, **kwargs: Any)

Bases: Module

Gated Linear Unit expert module for DBRX mixture of experts.

Implements a single expert network with gated activation for specialized processing in the MoE architecture.
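Here is a NumPy sketch of one gated-linear-unit expert. It assumes the `w1`/`v1`/`w2` naming and the SiLU activation used by the Hugging Face DBRX port. EasyDeL's parameter names, layout, and activation may differ.

```python
import numpy as np


def silu(x):
    # SiLU / swish activation: x * sigmoid(x).
    return x / (1.0 + np.exp(-x))


def expert_glu(x, w1, v1, w2, act=silu):
    """Single gated-linear-unit expert (sketch).

    x: (tokens, hidden); w1, v1: (hidden, ffn_hidden); w2: (ffn_hidden, hidden).
    """
    gate = act(x @ w1)        # activated gate branch
    up = x @ v1               # linear branch
    return (gate * up) @ w2   # element-wise gating, then project back to hidden


rng = np.random.default_rng(0)
hidden, ffn_hidden = 8, 16
x = rng.standard_normal((3, hidden))
w1, v1 = rng.standard_normal((2, hidden, ffn_hidden))
w2 = rng.standard_normal((ffn_hidden, hidden))
print(expert_glu(x, w1, v1, w2).shape)  # (3, 8)
```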

class easydel.modules.dbrx.modeling_dbrx.DbrxExperts(*args: Any, **kwargs: Any)

Bases: Module

Collection of expert networks for DBRX mixture of experts.

Manages multiple expert networks that can be selected and combined based on routing decisions for conditional computation.
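Here is a NumPy sketch of the selection-and-combination step. Each expert processes only the tokens routed to it, and its outputs are scaled by the routing weights and summed per token. Representing experts as a list of callables, and the helper name `combine_experts`, are illustrative simplifications rather than EasyDeL's implementation.

```python
import numpy as np


def combine_experts(x, expert_fns, top_idx, top_w):
    """Weighted combination of selected experts per token (sketch).

    x: (tokens, hidden); top_idx, top_w: (tokens, k) routing decisions.
    """
    out = np.zeros_like(x)
    for e, fn in enumerate(expert_fns):
        # Tokens (and their top-k slot) that were routed to expert e.
        token_ids, slot = np.nonzero(top_idx == e)
        if token_ids.size == 0:
            continue  # unused experts do no work: conditional computation
        weights = top_w[token_ids, slot][:, None]
        np.add.at(out, token_ids, weights * fn(x[token_ids]))
    return out


x = np.ones((3, 4))
experts = [lambda h: 1.0 * h, lambda h: 2.0 * h, lambda h: 3.0 * h]
top_idx = np.array([[0, 1], [1, 2], [0, 2]])
top_w = np.array([[0.5, 0.5], [0.25, 0.75], [1.0, 0.0]])
result = combine_experts(x, experts, top_idx, top_w)
print(result[:, 0].tolist())  # [1.5, 2.75, 1.0]
```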

class easydel.modules.dbrx.modeling_dbrx.DbrxFFN(*args: Any, **kwargs: Any)

Bases: Module

Feedforward network with mixture of experts for DBRX models.

Combines router and expert networks to implement sparse MoE feedforward layers with conditional computation.

class easydel.modules.dbrx.modeling_dbrx.DbrxForCausalLM(*args: Any, **kwargs: Any)

Bases: BaseCausalLMModule[DbrxModel, DbrxConfig]

DBRX model with a Causal Language Modeling head.

class easydel.modules.dbrx.modeling_dbrx.DbrxForSequenceClassification(*args: Any, **kwargs: Any)

Bases: BaseSequenceClassificationModule[DbrxModel, DbrxConfig]

DBRX model with a Sequence Classification head.

class easydel.modules.dbrx.modeling_dbrx.DbrxModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Base DBRX Model outputting raw hidden-states.

This is a decoder-only Transformer with a mixture-of-experts (MoE) architecture, implementing the DBRX architecture released by Databricks.

The model uses DBRX-specific attention modules (with a fused QKV projection) and a router-based MoE FFN layer.

property frequencies

Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_frequencies() and caches the result.

Returns

The frequency components, potentially cached.

Return type

jnp.ndarray
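The internals of get_basic_frequencies() are not documented here, so the following NumPy sketch only illustrates the general idea under several assumptions. It uses standard RoPE inverse frequencies with base `theta` and a `[cos | sin]` layout of shape `(seq_len, head_dim)`, matching the `frequencies` annotation of `DbrxAttention.forward`. It caches the result with `functools.cached_property`. The class `RopeFrequencies` and its parameters are hypothetical.

```python
import functools

import numpy as np


class RopeFrequencies:
    """Hypothetical stand-in for a cached `frequencies` property."""

    def __init__(self, head_dim, max_positions, theta=10000.0):
        self.head_dim = head_dim
        self.max_positions = max_positions
        self.theta = theta

    @functools.cached_property
    def frequencies(self):
        # Standard RoPE inverse frequencies: theta^(-2i/d) for i in [0, d/2).
        inv_freq = 1.0 / self.theta ** (np.arange(0, self.head_dim, 2) / self.head_dim)
        positions = np.arange(self.max_positions)
        # Angle for every (position, frequency) pair: (max_positions, head_dim // 2).
        angles = np.outer(positions, inv_freq)
        return np.concatenate([np.cos(angles), np.sin(angles)], axis=-1)


rope = RopeFrequencies(head_dim=8, max_positions=16)
freqs = rope.frequencies
print(freqs.shape)                # (16, 8)
print(rope.frequencies is freqs)  # True: computed once, then cached
```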

get_decoder() → Module

Returns the decoder part of the model’s graph definition. For DbrxModel, this is the model itself.

get_embedding() → Module

Returns the embedding layer of the module.

get_encoder() → Module

Returns the encoder part of the model’s graph definition. For DbrxModel (decoder-only), this is not applicable.

get_lm_head() → Module

Returns the language model head of the module. DbrxModel does not include the lm_head.

class easydel.modules.dbrx.modeling_dbrx.DbrxNormAttentionNorm(*args: Any, **kwargs: Any)

Bases: Module

Normalization-Attention-Normalization module for DBRX models.

Wraps the attention mechanism between two normalization layers: the first normalizes the input before attention, and the second normalizes the residual output before it is passed to the feedforward network.

static kernel_init(key: Array, shape: Sequence[Union[int, Any]], dtype: Any | None = None, out_sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None = None) → Array

An initializer that returns a constant array full of ones.

The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)

class easydel.modules.dbrx.modeling_dbrx.DbrxRouter(*args: Any, **kwargs: Any)

Bases: Module

Router module for DBRX mixture of experts.

Determines which experts to activate for each input token, implementing sparse routing for efficient computation.
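Here is a NumPy sketch of top-k softmax routing. It assumes the router is a single linear layer whose softmax probabilities are truncated to the `top_k` experts and renormalized to sum to one. The example's 16-expert, top-4 setting matches the published DBRX configuration, but the function `route` and its parameters are hypothetical.

```python
import numpy as np


def route(hidden, w_router, top_k, normalize=True):
    """Top-k softmax routing (sketch). hidden: (tokens, d); w_router: (d, num_experts)."""
    logits = hidden @ w_router
    # Softmax over experts (subtract the max for numerical stability).
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs = probs / probs.sum(axis=-1, keepdims=True)
    # Indices of the top_k most probable experts per token, highest first.
    top_idx = np.argsort(-probs, axis=-1)[:, :top_k]
    top_w = np.take_along_axis(probs, top_idx, axis=-1)
    if normalize:
        # Rescale the selected weights so they sum to 1 per token.
        top_w = top_w / top_w.sum(axis=-1, keepdims=True)
    return top_w, top_idx


rng = np.random.default_rng(0)
hidden = rng.standard_normal((4, 8))
w_router = rng.standard_normal((8, 16))
top_w, top_idx = route(hidden, w_router, top_k=4)
print(top_idx.shape, np.allclose(top_w.sum(axis=-1), 1.0))  # (4, 4) True
```

The returned `(top_w, top_idx)` pair has the `(tokens, k)` shape that an expert-combination step consumes.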

jitter(x: Union[Array, ndarray, bool, number]) → Union[Array, ndarray, bool, number]
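`jitter` has no description here. In the Hugging Face DBRX port, the router's jitter returns uniform noise in `[1 - moe_jitter_eps, 1 + moe_jitter_eps]`, which multiplies the router input during training. The NumPy sketch below assumes EasyDeL follows the same scheme. The explicit `rng` argument is a hypothetical simplification; in the sketch, the input `x` only determines the noise shape.

```python
import numpy as np


def jitter(x, jitter_eps, rng):
    """Multiplicative noise factors drawn uniformly from [1 - eps, 1 + eps) (sketch)."""
    low, high = 1.0 - jitter_eps, 1.0 + jitter_eps
    return rng.uniform(low, high, size=np.shape(x))


rng = np.random.default_rng(0)
hidden = np.ones((2, 8))
noisy = hidden * jitter(hidden, jitter_eps=0.01, rng=rng)
print(bool(np.all(np.abs(noisy - hidden) <= 0.01)))  # True
```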