easydel.modules.dbrx.modeling_dbrx#
- class easydel.modules.dbrx.modeling_dbrx.DbrxAttention(*args: Any, **kwargs: Any)[source]#
Bases: UnifiedAttention
DBRX Attention module with fused QKV projection.
This module implements the multi-head attention mechanism used in the DBRX model. It supports Grouped Query Attention (GQA) and Rotary Position Embeddings (RoPE). The query, key, and value projections are combined into a single fused linear layer for efficiency, and the fused output can optionally be clipped (QKV clipping).
Overrides forward_standard to efficiently handle fused QKV projection.
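As a rough illustration of the fused projection described above, the sketch below projects once, optionally clips, and then splits the result into Q, K, and V with GQA head counts. It is not EasyDeL's implementation: the function name `fused_qkv_split`, the argument names, and the Q-then-K-then-V layout of the fused output are assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def fused_qkv_split(hidden, w_qkv, num_heads, num_kv_heads, head_dim, clip_qkv=None):
    """Hypothetical helper: one matmul for Q, K and V, then clip and split."""
    # Single fused projection: (batch, seq, (num_heads + 2 * num_kv_heads) * head_dim)
    qkv = hidden @ w_qkv
    if clip_qkv is not None:
        # Optional QKV clipping to the range [-clip_qkv, clip_qkv]
        qkv = jnp.clip(qkv, -clip_qkv, clip_qkv)
    q_size = num_heads * head_dim
    kv_size = num_kv_heads * head_dim
    # Assumed layout: queries first, then keys, then values
    q, k, v = jnp.split(qkv, [q_size, q_size + kv_size], axis=-1)
    batch, seq, _ = hidden.shape
    q = q.reshape(batch, seq, num_heads, head_dim)
    k = k.reshape(batch, seq, num_kv_heads, head_dim)  # fewer KV heads under GQA
    v = v.reshape(batch, seq, num_kv_heads, head_dim)
    return q, k, v


batch, seq, hidden_dim = 2, 5, 32
num_heads, num_kv_heads, head_dim = 4, 2, 8
key_x, key_w = jax.random.split(jax.random.key(0))
hidden = jax.random.normal(key_x, (batch, seq, hidden_dim))
w_qkv = jax.random.normal(key_w, (hidden_dim, (num_heads + 2 * num_kv_heads) * head_dim))
q, k, v = fused_qkv_split(hidden, w_qkv, num_heads, num_kv_heads, head_dim, clip_qkv=8.0)
```

With these toy sizes, `q` has four heads while `k` and `v` have two, so each key/value head would be shared by two query heads.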
- config#
Configuration object for the model.
- Type
DbrxConfig
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- Wqkv#
Fused linear layer for query, key, and value projections.
- Type
- out_proj#
Linear layer for the output projection.
- Type
- attention_performer#
Module to perform the core attention computation.
- rotary#
Rotary position embedding module.
- Type
RoPE
- resid_dropout#
Residual dropout layer.
- Type
nn.Dropout
- define_network(config: DbrxConfig, dtype: dtype, param_dtype: dtype, precision: Union[None, str, Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], DotAlgorithm, DotAlgorithmPreset], rngs: Rngs)[source]#
Override to create fused QKV projection instead of separate Q/K/V.
- Parameters
config – Model configuration
dtype – Data type for computations
param_dtype – Data type for parameters
precision – JAX precision setting
rngs – Random number generators
- forward(hidden_states: Float[Array, 'batch seq_len hidden_dim'], mask_info: ejkernel.types.mask.MaskInfo | None, position_ids: Int[Array, 'batch seq_len'], mode: Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], cache_view: easydel.layers.caching.transformer.cache.TransformerCacheView | easydel.layers.caching.ragged_page.cache.RaggedPagesCacheView | None = None, cache_metadata: easydel.layers.caching.transformer.cache.TransformerMetadata | easydel.layers.caching.ragged_page.cache.RaggedPagesMetadata | None = None, output_attentions: bool = False, frequencies: jaxtyping.Float[Array, 'seq_len head_dim'] | None = None, alibi: jaxtyping.Float[Array, 'batch_or_1 heads qseq_len_or_1 kvseq_len_or_1'] | None = None)[source]#
Override to handle fused QKV projection efficiently with optional clipping.
- projection_mapping: ClassVar[dict[str, str]] = {'output_projection': 'out_proj', 'query_key_value_projection': 'Wqkv'}#
- class easydel.modules.dbrx.modeling_dbrx.DbrxBlock(*args: Any, **kwargs: Any)[source]#
Bases: Module
Single transformer block for DBRX models.
Integrates attention mechanisms with mixture of experts feedforward networks, using residual connections and normalization.
- class easydel.modules.dbrx.modeling_dbrx.DbrxExpertGLU(*args: Any, **kwargs: Any)[source]#
Bases: Module
Gated Linear Unit expert module for DBRX mixture of experts.
Implements a single expert network with gated activation for specialized processing in the MoE architecture.
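To make the gated activation concrete, here is a minimal sketch of a generic gated linear unit feedforward. The weight names (`w_gate`, `w_up`, `w_down`) and the choice of SiLU as the activation are assumptions for illustration; the actual parameter names and activation used by `DbrxExpertGLU` are not stated on this page.

```python
import jax
import jax.numpy as jnp


def glu_expert(x, w_gate, w_up, w_down):
    """Hypothetical GLU expert: an activated gate branch scales a linear branch."""
    gate = jax.nn.silu(x @ w_gate)  # gating branch (SiLU is an assumed activation)
    up = x @ w_up                   # linear branch
    return (gate * up) @ w_down     # project back to the hidden size


hidden_dim, ffn_dim = 8, 16
k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4)
x = jax.random.normal(k1, (3, hidden_dim))
w_gate = jax.random.normal(k2, (hidden_dim, ffn_dim))
w_up = jax.random.normal(k3, (hidden_dim, ffn_dim))
w_down = jax.random.normal(k4, (ffn_dim, hidden_dim))
y = glu_expert(x, w_gate, w_up, w_down)
y_zero = glu_expert(jnp.zeros_like(x), w_gate, w_up, w_down)
```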
- class easydel.modules.dbrx.modeling_dbrx.DbrxExperts(*args: Any, **kwargs: Any)[source]#
Bases: Module
Collection of expert networks for DBRX mixture of experts.
Manages multiple expert networks that can be selected and combined based on routing decisions for conditional computation.
- class easydel.modules.dbrx.modeling_dbrx.DbrxFFN(*args: Any, **kwargs: Any)[source]#
Bases: Module
Feedforward network with mixture of experts for DBRX models.
Combines router and expert networks to implement sparse MoE feedforward layers with conditional computation.
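The router-plus-experts combination can be sketched as top-k routing over a small set of toy experts. This is a simplified, dense illustration and not the EasyDeL code path: `moe_ffn`, `toy_expert`, the renormalization of the top-k weights, and running every expert on every token are all assumptions chosen to keep the example short.

```python
import jax
import jax.numpy as jnp


def toy_expert(x, params):
    """Hypothetical stand-in for an expert network."""
    return jnp.tanh(x @ params["w_in"]) @ params["w_out"]


def moe_ffn(x, w_router, expert_params, top_k):
    """Route each token to its top_k experts and mix their outputs."""
    probs = jax.nn.softmax(x @ w_router, axis=-1)        # (tokens, num_experts)
    top_w, top_idx = jax.lax.top_k(probs, top_k)         # (tokens, top_k)
    top_w = top_w / jnp.sum(top_w, axis=-1, keepdims=True)  # assumed renormalization
    num_experts = w_router.shape[-1]
    # Dense combine weights: zero for experts a token was not routed to
    combine = jnp.sum(jax.nn.one_hot(top_idx, num_experts) * top_w[..., None], axis=1)
    # Simple but wasteful: evaluate all experts, then mask via the combine weights
    expert_out = jax.vmap(lambda p: toy_expert(x, p))(expert_params)  # (experts, tokens, hidden)
    return jnp.einsum("te,eth->th", combine, expert_out), combine


tokens, hidden_dim, ffn_dim, num_experts, top_k = 6, 8, 16, 4, 2
k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4)
x = jax.random.normal(k1, (tokens, hidden_dim))
w_router = jax.random.normal(k2, (hidden_dim, num_experts))
expert_params = {
    "w_in": jax.random.normal(k3, (num_experts, hidden_dim, ffn_dim)),
    "w_out": jax.random.normal(k4, (num_experts, ffn_dim, hidden_dim)),
}
out, combine = moe_ffn(x, w_router, expert_params, top_k)
```

A production MoE layer would normally dispatch tokens only to their selected experts; the dense form here trades efficiency for readability.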
- class easydel.modules.dbrx.modeling_dbrx.DbrxForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases: BaseCausalLMModule[DbrxModel, DbrxConfig]
DBRX model with a Causal Language Modeling head.
- class easydel.modules.dbrx.modeling_dbrx.DbrxForSequenceClassification(*args: Any, **kwargs: Any)[source]#
Bases: BaseSequenceClassificationModule[DbrxModel, DbrxConfig]
DBRX model with a Sequence Classification head.
- class easydel.modules.dbrx.modeling_dbrx.DbrxModel(*args: Any, **kwargs: Any)[source]#
Bases: EasyDeLBaseModule
Base DBRX Model outputting raw hidden-states.
This model is a Transformer-based model with a mixture of experts (MoE) architecture, implementing the DBRX architecture as described in the original paper.
The model uses specialized attention modules and a router-based MoE FFN layer.
- property frequencies#
Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.
Uses self.config.get_basic_frequencies() and caches the result.
- Returns
The frequency components, potentially cached.
- Return type
jnp.ndarray
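For readers unfamiliar with RoPE frequency tables, the following sketch computes the standard rotary angle table and caches it, loosely mirroring the retrieve-or-compute behavior described above. It does not call `get_basic_frequencies()`; the function `rope_angles`, its arguments, the default base of 10000, and the `(positions, head_dim // 2)` layout are assumptions, and the array layout EasyDeL actually returns may differ.

```python
from functools import lru_cache

import jax.numpy as jnp


@lru_cache(maxsize=None)
def rope_angles(head_dim, max_positions, base=10000.0):
    """Hypothetical helper: rotation angle for each (position, frequency) pair."""
    # Standard RoPE inverse frequencies: base ** (-2i / head_dim)
    inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim))
    positions = jnp.arange(max_positions, dtype=jnp.float32)
    return jnp.outer(positions, inv_freq)  # (max_positions, head_dim // 2)


angles = rope_angles(8, 16)
angles_again = rope_angles(8, 16)  # served from the cache, not recomputed
```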
- get_decoder() Module[source]#
Returns the decoder part of the model’s graph definition. For DbrxModel, this is the model itself.
- class easydel.modules.dbrx.modeling_dbrx.DbrxNormAttentionNorm(*args: Any, **kwargs: Any)[source]#
Bases: Module
Normalization-Attention-Normalization module for DBRX models.
Implements a unique architecture pattern with normalization layers surrounding the attention mechanism for improved gradient flow.
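One plausible reading of the normalization-attention-normalization pattern is sketched below: normalize, attend, add the residual, then normalize again so that the feedforward receives a normalized input while the residual stream is carried alongside. The helper names, the returned `(residual, normed)` pair, and the stand-in `attn_fn` are assumptions for illustration and are not taken from the EasyDeL source.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, scale, eps=1e-5):
    """Plain layer normalization over the last axis with a learned scale."""
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps) * scale


def norm_attention_norm(x, attn_fn, scale_1, scale_2):
    """Hypothetical norm -> attention -> residual add -> norm pattern."""
    h = attn_fn(layer_norm(x, scale_1))  # pre-attention normalization
    residual = x + h                     # residual connection around attention
    return residual, layer_norm(residual, scale_2)  # normalized input for the FFN


x = jax.random.normal(jax.random.key(0), (2, 3, 8))
scale = jnp.ones((8,))  # scales initialized to ones
residual, normed = norm_attention_norm(x, lambda h: 0.5 * h, scale, scale)
```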
- static kernel_init(key: Array, shape: Sequence[Union[int, Any]], dtype: Any | None = None, out_sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None = None) Array#
An initializer that returns a constant array full of ones.
The key argument is ignored.
>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)