easydel.modules.cohere.modeling_cohere
- class easydel.modules.cohere.modeling_cohere.CohereAttention(*args: Any, **kwargs: Any)
Bases: UnifiedAttention
Multi-head attention layer with rotary position embeddings (RoPE) for Cohere models.
Inherits from UnifiedAttention with Cohere-specific customizations:
  - Optional Q/K normalization (use_qk_norm)
  - Custom RoPE configuration
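The RoPE rotation and the optional Q/K normalization described above can be illustrated with a minimal JAX sketch. This is not EasyDeL's CohereAttention: the helper names (`rope_frequencies`, `apply_rope_interleaved`, `qk_norm`, `attention`), the frequency base of 10000, the parameter-free Q/K norm, and the choice to rotate interleaved dimension pairs (implementations differ between interleaved and half-split pairing) are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def rope_frequencies(seq_len: int, head_dim: int, base: float = 10000.0):
    """Return cos/sin tables of shape (seq_len, head_dim // 2). `base` is an assumed default."""
    inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2) / head_dim))
    angles = jnp.outer(jnp.arange(seq_len), inv_freq)  # (seq_len, head_dim // 2)
    return jnp.cos(angles), jnp.sin(angles)


def apply_rope_interleaved(x, cos, sin):
    """Rotate consecutive pairs (x[2i], x[2i+1]) of the last axis by a position-dependent angle.

    x: (..., seq_len, head_dim); cos/sin: (seq_len, head_dim // 2).
    """
    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]
    rot_even = x_even * cos - x_odd * sin
    rot_odd = x_even * sin + x_odd * cos
    # Re-interleave the rotated pairs back into the original layout.
    return jnp.stack([rot_even, rot_odd], axis=-1).reshape(x.shape)


def qk_norm(x, eps: float = 1e-6):
    """Hypothetical per-head RMS normalization of queries/keys (no learned scale here)."""
    return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)


def attention(q, k, v, use_qk_norm: bool = False):
    """Causal scaled dot-product attention for q, k, v of shape (heads, seq, head_dim)."""
    seq_len, head_dim = q.shape[-2], q.shape[-1]
    if use_qk_norm:
        q, k = qk_norm(q), qk_norm(k)
    cos, sin = rope_frequencies(seq_len, head_dim)
    q = apply_rope_interleaved(q, cos, sin)
    k = apply_rope_interleaved(k, cos, sin)
    scores = q @ jnp.swapaxes(k, -1, -2) / jnp.sqrt(head_dim)
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))  # token t sees positions <= t
    scores = jnp.where(mask, scores, -jnp.inf)
    return jax.nn.softmax(scores, axis=-1) @ v
```

Because RoPE is a pure rotation, it leaves vector norms unchanged and leaves position 0 untouched; only the relative angle between a query and a key changes their dot product.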
- class easydel.modules.cohere.modeling_cohere.CohereBlock(*args: Any, **kwargs: Any)
Bases: Module
Single transformer block for Cohere models.
Combines self-attention, a feedforward network, and normalization with residual connections to form a complete transformer layer.
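The description of CohereBlock does not say whether attention and the feedforward network run one after the other or side by side. The sketch below contrasts the two layouts. The parallel variant mirrors the Hugging Face Transformers Cohere port, where one normalization feeds both branches and both outputs are added to the residual stream; whether EasyDeL's CohereBlock follows it exactly is an assumption. `norm`, `attn`, and `mlp` are placeholder callables, not EasyDeL APIs.

```python
from typing import Callable

import jax
import jax.numpy as jnp

Array = jax.Array


def parallel_block(
    x: Array,
    norm: Callable[[Array], Array],
    attn: Callable[[Array], Array],
    mlp: Callable[[Array], Array],
) -> Array:
    """One layer with a single pre-norm and parallel residual branches (assumed layout)."""
    h = norm(x)  # one shared normalization feeds both branches
    return x + attn(h) + mlp(h)  # both branch outputs join the residual stream together


def sequential_block(x, norm1, attn, norm2, mlp):
    """The common sequential pre-norm layout, shown for comparison."""
    x = x + attn(norm1(x))  # attention sublayer with residual
    return x + mlp(norm2(x))  # MLP sees the attention output of the same layer
```

In the parallel layout the MLP does not see the same layer's attention output, but the two branches have no data dependency on each other and can be computed concurrently.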
- class easydel.modules.cohere.modeling_cohere.CohereForCausalLM(*args: Any, **kwargs: Any)
Bases: BaseCausalLMModule[CohereModel, CohereConfig]
Cohere model with a causal language modeling head.
  - get_decoder() → Module
    Returns the decoder part of the model's graph definition. For CohereForCausalLM, this is the underlying CohereModel.
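What a causal language modeling head computes on top of the decoder output can be sketched in a few lines of JAX. The tied embedding matrix and the scalar `logit_scale` multiplier are assumptions for illustration (the Hugging Face Cohere port uses both, but check CohereConfig before relying on them here); `lm_logits` and `causal_lm_loss` are hypothetical helper names, not EasyDeL functions.

```python
import jax
import jax.numpy as jnp


def lm_logits(hidden, embedding, logit_scale: float = 1.0):
    """Project hidden states (batch, seq, d) to vocab logits via a tied (vocab, d) embedding."""
    return (hidden @ embedding.T) * logit_scale


def causal_lm_loss(logits, input_ids):
    """Mean next-token cross-entropy: position t predicts token t + 1."""
    shifted_logits = logits[:, :-1, :]  # the last position has no next token to predict
    targets = input_ids[:, 1:]  # the first token is never a prediction target
    log_probs = jax.nn.log_softmax(shifted_logits, axis=-1)
    token_log_probs = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    return -jnp.mean(token_log_probs)
```

With all-zero logits every token is equally likely, so the loss equals log(vocab_size), which is a quick sanity check for any causal LM loss.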
- class easydel.modules.cohere.modeling_cohere.CohereForSequenceClassification(*args: Any, **kwargs: Any)
Bases: BaseSequenceClassificationModule[CohereModel, CohereConfig]
Cohere model for sequence classification.
  - get_decoder() → Module
    Returns the decoder part of the model's graph definition. For CohereForSequenceClassification, this is the underlying CohereModel.
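Decoder-only classifiers usually score a single pooled hidden state per sequence. The sketch below pools the state of the last non-padding token and applies a linear score head; that pooling rule is the convention in Hugging Face decoder `*ForSequenceClassification` models, and whether BaseSequenceClassificationModule does exactly this is an assumption. `last_token_pool`, `classify`, and `score_weight` are hypothetical names.

```python
import jax.numpy as jnp


def last_token_pool(hidden, attention_mask):
    """Pick the hidden state of the last non-padding token for each sequence.

    hidden: (batch, seq, d); attention_mask: (batch, seq) with 1 for real tokens.
    Assumes right padding (real tokens first, padding after).
    """
    last_idx = attention_mask.sum(axis=-1) - 1  # index of the final real token
    return hidden[jnp.arange(hidden.shape[0]), last_idx]


def classify(hidden, attention_mask, score_weight):
    """Map pooled states (batch, d) to class logits with a (d, num_labels) projection."""
    return last_token_pool(hidden, attention_mask) @ score_weight
```

The last token is used because, under causal attention, it is the only position whose hidden state has attended to the whole sequence.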
- class easydel.modules.cohere.modeling_cohere.CohereMLP(*args: Any, **kwargs: Any)
Bases: Module
Multi-layer perceptron (feedforward) module for Cohere models.
Implements a gated feedforward network (a gated linear unit) with a configurable activation function.
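A gated feedforward network applies the activation to one projection and uses it to scale a second, linear projection elementwise before projecting back down. The following JAX sketch assumes SiLU as the default activation and plain weight matrices without biases; `gated_mlp` and the `w_gate`/`w_up`/`w_down` names are hypothetical, not CohereMLP's actual attributes.

```python
import jax
import jax.numpy as jnp


def gated_mlp(x, w_gate, w_up, w_down, act=jax.nn.silu):
    """Gated feedforward: down(act(x @ w_gate) * (x @ w_up)).

    x: (..., d_model); w_gate, w_up: (d_model, d_ff); w_down: (d_ff, d_model).
    The activation is configurable; SiLU is only an assumed default.
    """
    gate = act(x @ w_gate)  # activation applied only to the gate branch
    up = x @ w_up  # linear branch, modulated elementwise by the gate
    return (gate * up) @ w_down  # project back to d_model
```

Passing a different callable as `act` (for example `jax.nn.gelu`) is how the "configurable activation" would plug in under this sketch.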
- class easydel.modules.cohere.modeling_cohere.CohereModel(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
Decoder-only Cohere transformer that assembles the token embeddings, the stack of transformer blocks, and the final normalization layer.
  - get_decoder() → Module
    Returns the decoder part of the model's graph definition. For CohereModel, this is the model itself.
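The order in which CohereModel assembles its parts can be sketched as a plain function: look up embeddings, pass the hidden states through each block in turn, then apply the final norm. `decoder_forward` is a hypothetical name, and the layers and norm are placeholder callables; masking, caching, and sharding are omitted.

```python
from typing import Callable, Sequence

import jax
import jax.numpy as jnp


def decoder_forward(
    input_ids: jax.Array,
    embedding: jax.Array,
    layers: Sequence[Callable[[jax.Array], jax.Array]],
    final_norm: Callable[[jax.Array], jax.Array],
) -> jax.Array:
    """Embed token ids, run them through each block in order, then apply the final norm."""
    hidden = embedding[input_ids]  # (batch, seq) -> (batch, seq, d_model)
    for layer in layers:  # each layer maps (batch, seq, d) -> (batch, seq, d)
        hidden = layer(hidden)
    return final_norm(hidden)
```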
- class easydel.modules.cohere.modeling_cohere.RMSNorm(*args: Any, **kwargs: Any)
Bases: Module
Root Mean Square layer normalization for Cohere models.
Normalizes each input vector by its root mean square and multiplies it by a learnable per-feature scale; unlike standard LayerNorm, it does not subtract the mean.
  - static kernel_init(key: Array, shape: Sequence[Union[int, Any]], dtype: Any | None = None, out_sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None = None) → Array
    An initializer that returns a constant array full of ones. The key argument is ignored.
    >>> import jax, jax.numpy as jnp
    >>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
    Array([[1., 1.],
           [1., 1.],
           [1., 1.]], dtype=float32)
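RMS normalization divides each vector by the root mean square of its entries and then applies a learnable scale, which starts at ones when initialized with an all-ones initializer such as the one documented above. The JAX sketch below assumes an epsilon of 1e-6 and a scale over the last axis; `rms_norm` is a hypothetical helper, not EasyDeL's RMSNorm module.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps: float = 1e-6):
    """RMS normalization over the last axis with a learnable per-feature scale.

    No mean is subtracted: x is divided by sqrt(mean(x**2) + eps), then scaled.
    eps = 1e-6 is an assumed default.
    """
    rms_inv = jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
    return x * rms_inv * weight


d_model = 4
# The scale starts at ones, as with an all-ones initializer.
weight = jax.nn.initializers.ones(jax.random.PRNGKey(0), (d_model,), jnp.float32)
```

A constant input such as `[2., 2., 2., 2.]` normalizes to roughly ones here, whereas LayerNorm would map it to zeros because it subtracts the mean first.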