easydel.modules.cohere2.modeling_cohere2
- class easydel.modules.cohere2.modeling_cohere2.Cohere2Attention(*args: Any, **kwargs: Any)
Bases: UnifiedAttention
Cohere2 attention with a layer-specific sliding window and conditional RoPE.
Inherits from UnifiedAttention with the following Cohere2-specific customizations:
- Layer-specific sliding window: the window is applied only on layers of type sliding_attention.
- Conditional RoPE: rotary position embeddings are applied only when the sliding window is enabled for the layer.
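The following NumPy sketch illustrates both customizations. jax.numpy accepts the same calls. It is not EasyDeL's code: layer_types, attention_mask, rope, and maybe_rope are hypothetical helpers. Two details are assumed from the Hugging Face Cohere2 implementation rather than taken from this page: every fourth layer is global, and RoPE rotates interleaved dimension pairs.

```python
import numpy as np


def layer_types(num_layers, pattern=4):
    # Hypothetical helper: every `pattern`-th layer is global ("full_attention"),
    # the others use the sliding window (assumed Hugging Face Cohere2 convention).
    return [
        "sliding_attention" if (i + 1) % pattern else "full_attention"
        for i in range(num_layers)
    ]


def attention_mask(seq_len, layer_type, sliding_window):
    # Boolean causal mask: query position i may attend to key positions j <= i.
    q = np.arange(seq_len)[:, None]
    k = np.arange(seq_len)[None, :]
    mask = k <= q
    if layer_type == "sliding_attention":
        # Sliding layers also drop keys `sliding_window` or more steps back.
        mask &= (q - k) < sliding_window
    return mask


def rope(x, positions, base=10000.0):
    # Rotary embedding over the last axis, rotating interleaved pairs (2i, 2i+1).
    x = np.asarray(x, dtype=np.float64)
    half = x.shape[-1] // 2
    inv_freq = 1.0 / base ** (np.arange(half) / half)
    angles = positions[:, None] * inv_freq[None, :]  # (seq, half)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out


def maybe_rope(x, positions, layer_type):
    # Conditional RoPE: only sliding-window layers receive rotary positions;
    # global layers see the unrotated projections.
    return rope(x, positions) if layer_type == "sliding_attention" else x
```

Under these assumptions, a global layer sees the whole causal prefix without positional rotation. A sliding layer sees only its local window, with RoPE applied.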
- class easydel.modules.cohere2.modeling_cohere2.Cohere2Block(*args: Any, **kwargs: Any)
Bases: Module
Cohere v2 transformer block combining self-attention and an MLP.
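The block's internal wiring is not documented here. The Hugging Face Cohere implementations use a parallel residual layout: one input norm feeds attention and the MLP side by side. Assuming Cohere2Block does the same, the sketch below contrasts it with a sequential pre-norm block, using scalar stand-ins. All function names are hypothetical.

```python
def parallel_block(x, norm_fn, attn_fn, mlp_fn):
    # Parallel residual: a single input norm feeds both sub-layers, and both
    # outputs are added to the same residual stream.
    h = norm_fn(x)
    return x + attn_fn(h) + mlp_fn(h)


def sequential_block(x, norm1_fn, norm2_fn, attn_fn, mlp_fn):
    # Conventional pre-norm layout, for contrast only: the MLP sees the
    # output of the attention sub-layer.
    x = x + attn_fn(norm1_fn(x))
    return x + mlp_fn(norm2_fn(x))


def double(v):
    return 2 * v


def add_one(v):
    return v + 1


def triple(v):
    return 3 * v


# Scalar stand-ins make the difference visible; real code would pass arrays.
print(parallel_block(1.0, double, add_one, triple))            # 10.0
print(sequential_block(1.0, double, double, add_one, triple))  # 28.0
```

In the parallel layout, the attention and MLP sub-layers do not depend on each other's output, so they can be computed concurrently.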
- class easydel.modules.cohere2.modeling_cohere2.Cohere2ForCausalLM(*args: Any, **kwargs: Any)
Bases: BaseCausalLMModule[Cohere2Model, Cohere2Config]
Cohere2 model with a causal language modeling head.
- apply_lm_head(hidden_states: Union[Array, ndarray, bool, number]) → Union[Array, ndarray, bool, number]
Applies the language model head to the hidden states.
- Parameters
hidden_states (chex.Array) – The last hidden states from the model.
- Returns
The logits after applying the language model head.
- Return type
chex.Array
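The sketch below shows what an LM head of this kind computes. It assumes two things not stated on this page: tied input/output embeddings, and the constant logit_scale multiplier used by the Hugging Face Cohere models. Whether EasyDeL's apply_lm_head applies that scale is likewise an assumption. lm_head_sketch is a hypothetical name.

```python
import numpy as np


def lm_head_sketch(hidden_states, embedding, logit_scale):
    # Hypothetical helper. hidden_states: (batch, seq, hidden);
    # embedding: (vocab, hidden), reused as the output projection (tied weights).
    logits = hidden_states @ embedding.T  # (batch, seq, vocab)
    # Cohere-family models multiply the logits by a constant logit_scale.
    return logits * logit_scale
```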
- get_decoder() → Module
Returns the decoder part of the model’s graph definition. For Cohere2ForCausalLM, this is the underlying Cohere2Model.
- class easydel.modules.cohere2.modeling_cohere2.Cohere2ForSequenceClassification(*args: Any, **kwargs: Any)
Bases: BaseSequenceClassificationModule[Cohere2Model, Cohere2Config]
Cohere2 model for sequence classification.
- get_decoder() → Module
Returns the decoder part of the model’s graph definition. For Cohere2ForSequenceClassification, this is the underlying Cohere2Model.
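Decoder-only classifiers commonly score every position, then read out the logits at the last non-padding token; the Hugging Face *ForSequenceClassification models work this way. The sketch below shows that pooling step, assuming BaseSequenceClassificationModule pools the same way. last_token_logits is a hypothetical helper.

```python
import numpy as np


def last_token_logits(logits, input_ids, pad_token_id):
    # logits: (batch, seq, num_labels) from a per-position classification head.
    # Position index times the non-pad mask peaks at the rightmost real token,
    # so this works for both left- and right-padded batches.
    non_pad = input_ids != pad_token_id
    positions = np.arange(input_ids.shape[-1])
    last_idx = (positions * non_pad).argmax(axis=-1)
    return logits[np.arange(logits.shape[0]), last_idx]
```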
- class easydel.modules.cohere2.modeling_cohere2.Cohere2LayerNorm(*args: Any, **kwargs: Any)
Bases: Module
Cohere layer normalization.
- dim
The dimension(s) to normalize over.
- Type
Union[int, tuple]
- eps
A small epsilon value to prevent division by zero.
- Type
float
- dtype
The data type for computation.
- Type
jnp.dtype
- param_dtype
The data type for the parameters.
- Type
jnp.dtype
- rngs
Random number generators.
- Type
Optional[nn.Rngs]
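The normalization itself is not spelled out above. The sketch below assumes the behavior of the Hugging Face CohereLayerNorm, which has three features:
- Statistics are taken over the last axis.
- The affine step is scale-only, with no bias.
- Computation runs in float32, with a cast back to the input dtype.

cohere_layer_norm is a hypothetical name.

```python
import numpy as np


def cohere_layer_norm(x, weight, eps=1e-5):
    # Compute in float32 for stability, then cast back to the input dtype.
    orig_dtype = x.dtype
    x32 = x.astype(np.float32)
    mean = x32.mean(axis=-1, keepdims=True)
    var = ((x32 - mean) ** 2).mean(axis=-1, keepdims=True)
    normed = (x32 - mean) / np.sqrt(var + eps)
    # Scale only: there is no bias term, unlike a standard LayerNorm.
    return (weight.astype(np.float32) * normed).astype(orig_dtype)
```

A weight of shape (num_heads, head_dim) broadcasts against inputs of shape (..., num_heads, head_dim), giving each head its own scale.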
- static kernel_init(key: Array, shape: Sequence[Union[int, Any]], dtype: Any | None = None, out_sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None = None) → Array
An initializer that returns a constant array full of ones.
The key argument is ignored.
>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)
- class easydel.modules.cohere2.modeling_cohere2.Cohere2MLP(*args: Any, **kwargs: Any)
Bases: Module
Feed-forward network used in Cohere v2 decoder layers.
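The sketch below assumes Cohere2MLP follows the gated SiLU (SwiGLU-style) layout of the Hugging Face Cohere MLP, down_proj(silu(gate_proj(x)) * up_proj(x)) without biases. Under that assumption, its forward pass reduces to three matrix products. The weight names below are hypothetical.

```python
import numpy as np


def silu(x):
    # SiLU / swish: x * sigmoid(x).
    return x / (1.0 + np.exp(-x))


def gated_mlp(x, w_gate, w_up, w_down):
    # x: (..., hidden); w_gate, w_up: (hidden, intermediate);
    # w_down: (intermediate, hidden). The gate branch modulates the up branch.
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down
```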
- class easydel.modules.cohere2.modeling_cohere2.Cohere2Model(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
Decoder-only Cohere v2 model with token embeddings, a stack of decoder blocks, and a final norm.
- get_decoder() → Module
Returns the decoder part of the model’s graph definition. For Cohere2Model, this is the model itself.
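The data flow of Cohere2Model (embeddings, then decoder blocks, then final norm) can be sketched as follows. The blocks and norm are toy stand-ins. Passing each block a layer_type argument is an assumption about how the per-layer attention behavior is selected. cohere2_forward_sketch is a hypothetical name.

```python
import numpy as np


def cohere2_forward_sketch(input_ids, embedding, blocks, layer_types, final_norm):
    # 1. Token embeddings: one row of the (vocab, hidden) table per token id.
    h = embedding[input_ids]
    # 2. Decoder blocks, each told whether it is a sliding or a global layer.
    for block, layer_type in zip(blocks, layer_types):
        h = block(h, layer_type)
    # 3. Final norm; a causal-LM or classification head would consume this.
    return final_norm(h)


# Toy usage: blocks that add 1 and record the layer type they were given.
seen = []


def toy_block(h, layer_type):
    seen.append(layer_type)
    return h + 1.0


emb = np.arange(12, dtype=float).reshape(4, 3)  # vocab=4, hidden=3
out = cohere2_forward_sketch(
    np.array([[0, 2]]), emb, [toy_block, toy_block],
    ["sliding_attention", "full_attention"], final_norm=lambda h: h,
)
print(out.shape)  # (1, 2, 3)
```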