easydel.modules.cohere2.modeling_cohere2

class easydel.modules.cohere2.modeling_cohere2.Cohere2Attention(*args: Any, **kwargs: Any)

Bases: UnifiedAttention

Cohere2 Attention with layer-specific sliding window and conditional RoPE.

Inherits from UnifiedAttention with Cohere2-specific customizations:

- Layer-specific sliding window (applies only to sliding_attention layers)
- Conditional RoPE application (only when the sliding window is enabled)
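
The two customizations can be illustrated with a minimal JAX sketch. This is not the EasyDeL implementation: the helper names (`causal_mask`, `sliding_window_mask`, `plan_for_layer`) and the window convention (the `window` most recent keys, self included) are assumptions made for illustration. Only the `sliding_attention` layer-type label comes from the description above.

```python
import jax.numpy as jnp


def causal_mask(seq_len: int) -> jnp.ndarray:
    """Boolean [seq_len, seq_len] mask: query i may attend to keys j <= i."""
    q = jnp.arange(seq_len)[:, None]
    k = jnp.arange(seq_len)[None, :]
    return k <= q


def sliding_window_mask(seq_len: int, window: int) -> jnp.ndarray:
    """Causal mask limited to the `window` most recent keys, self included."""
    q = jnp.arange(seq_len)[:, None]
    k = jnp.arange(seq_len)[None, :]
    return (k <= q) & (q - k < window)


def plan_for_layer(layer_type: str, seq_len: int, window: int):
    """Return (mask, apply_rope) for one layer (hypothetical helper)."""
    is_sliding = layer_type == "sliding_attention"
    mask = sliding_window_mask(seq_len, window) if is_sliding else causal_mask(seq_len)
    # Per the description above, RoPE is applied only on sliding-window layers.
    return mask, is_sliding
```

Under these assumptions, a layer of any other type gets the plain causal mask and skips RoPE.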

class easydel.modules.cohere2.modeling_cohere2.Cohere2Block(*args: Any, **kwargs: Any)

Bases: Module

Cohere v2 transformer block combining attention and MLP.
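
One way such a block can combine the two sub-layers is a parallel residual, where attention and MLP both read the same normalized input and their outputs are added to the residual stream. The sketch below assumes that layout, which this page does not state; `parallel_block_sketch` and its callable arguments are hypothetical stand-ins for the real sub-modules.

```python
import jax.numpy as jnp


def parallel_block_sketch(x, norm_fn, attn_fn, mlp_fn):
    """Assumed parallel-residual layout: x + attn(norm(x)) + mlp(norm(x))."""
    h = norm_fn(x)  # one shared normalization feeds both branches
    return x + attn_fn(h) + mlp_fn(h)


# Toy usage with stand-in callables instead of real attention and MLP modules.
x = jnp.ones((1, 2, 4))
out = parallel_block_sketch(x, lambda t: t, lambda t: 2.0 * t, lambda t: 3.0 * t)
```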

class easydel.modules.cohere2.modeling_cohere2.Cohere2ForCausalLM(*args: Any, **kwargs: Any)

Bases: BaseCausalLMModule[Cohere2Model, Cohere2Config]

Cohere2 model with a Causal Language Modeling head.

apply_lm_head(hidden_states: Union[Array, ndarray, bool, number]) -> Union[Array, ndarray, bool, number]

Applies the language model head to the hidden states.

Parameters

hidden_states (chex.Array) – The last hidden states from the model.

Returns

The logits after applying the language model head.

Return type

chex.Array
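
Conceptually, a language model head is a linear projection from the hidden size to the vocabulary size. The sketch below shows only that projection. It assumes a bias-free kernel of shape `[hidden, vocab]`, and `apply_lm_head_sketch` is a hypothetical name; the actual method may apply additional steps, such as weight tying or logit scaling, that this page does not describe.

```python
import jax.numpy as jnp


def apply_lm_head_sketch(hidden_states, lm_head_kernel):
    """Project [batch, seq, hidden] hidden states to [batch, seq, vocab] logits."""
    return jnp.einsum("bsh,hv->bsv", hidden_states, lm_head_kernel)


hidden = jnp.ones((2, 3, 4))   # batch=2, seq=3, hidden=4
kernel = jnp.ones((4, 5))      # hidden=4, vocab=5 (assumed layout)
logits = apply_lm_head_sketch(hidden, kernel)
```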

get_decoder() -> Module

Returns the decoder part of the model’s graph definition. For Cohere2ForCausalLM, this is the underlying Cohere2Model.

get_embedding() -> Module

Returns the embedding layer of the module.

get_encoder() -> Module

Returns the encoder part of the model’s graph definition. For Cohere2ForCausalLM (decoder-only), this is not applicable.

get_lm_head() -> Module

Returns the language model head of the module.

class easydel.modules.cohere2.modeling_cohere2.Cohere2ForSequenceClassification(*args: Any, **kwargs: Any)

Bases: BaseSequenceClassificationModule[Cohere2Model, Cohere2Config]

Cohere2 model for sequence classification.

get_decoder() -> Module

Returns the decoder part of the model’s graph definition. For Cohere2ForSequenceClassification, this is the underlying Cohere2Model.

get_embedding() -> Module

Returns the embedding layer of the module.

get_encoder() -> Module

Returns the encoder part of the model’s graph definition. For Cohere2ForSequenceClassification (decoder-only), this is not applicable.

get_lm_head() -> Module

Returns the language model head of the module. Cohere2ForSequenceClassification uses a classification head instead.
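
A classification head on a decoder-only model typically pools one hidden state per sequence and projects it to class scores. The sketch below assumes last-non-padding-token pooling with right padding and a bias-free score kernel; none of these details are confirmed by this page, and `classify_sketch` is a hypothetical name.

```python
import jax.numpy as jnp


def classify_sketch(hidden_states, attention_mask, score_kernel):
    """Pool the last non-padding token per sequence, then project to labels.

    Assumes right padding, so the last real token sits at index sum(mask) - 1.
    """
    last_idx = attention_mask.sum(axis=-1) - 1
    batch_idx = jnp.arange(hidden_states.shape[0])
    pooled = hidden_states[batch_idx, last_idx]  # [batch, hidden]
    return pooled @ score_kernel                 # [batch, num_labels]
```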

class easydel.modules.cohere2.modeling_cohere2.Cohere2LayerNorm(*args: Any, **kwargs: Any)

Bases: Module

Cohere Layer Normalization.

dim

The dimension(s) to normalize over.

Type

Union[int, tuple]

eps

A small epsilon value to prevent division by zero.

Type

float

dtype

The data type for computation.

Type

jnp.dtype

param_dtype

The data type for the parameters.

Type

jnp.dtype

rngs

Random number generators.

Type

Optional[nn.Rngs]
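
To show how `eps` and a learned scale enter a layer normalization, here is a minimal sketch. It assumes normalization over the last axis, a scale with no bias, and float32 accumulation; these are illustrative choices rather than confirmed details of Cohere2LayerNorm, and `layer_norm_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def layer_norm_sketch(x, weight, eps=1e-5):
    """Mean-center and variance-normalize the last axis, then scale by `weight`."""
    x32 = x.astype(jnp.float32)  # accumulate statistics in float32
    mean = jnp.mean(x32, axis=-1, keepdims=True)
    var = jnp.mean(jnp.square(x32 - mean), axis=-1, keepdims=True)
    normed = (x32 - mean) * jax.lax.rsqrt(var + eps)  # eps guards against var == 0
    return (weight.astype(jnp.float32) * normed).astype(x.dtype)


x = jnp.array([[1.0, 2.0, 3.0, 4.0]])
weight = jnp.ones((4,))  # a scale of all ones leaves the normalized values unchanged
y = layer_norm_sketch(x, weight)
```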

static kernel_init(key: Array, shape: Sequence[Union[int, Any]], dtype: Any | None = None, out_sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None = None) -> Array

An initializer that returns a constant array full of ones.

The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)

class easydel.modules.cohere2.modeling_cohere2.Cohere2MLP(*args: Any, **kwargs: Any)

Bases: Module

Feed-forward network used in Cohere v2 decoder layers.
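
Decoder feed-forward networks of this family are commonly gated: one projection passes through an activation and multiplies a second projection before a final down-projection. The sketch below assumes a SiLU-gated layout with bias-free kernels, which this page does not confirm; `gated_mlp_sketch` and the kernel names are hypothetical.

```python
import jax
import jax.numpy as jnp


def gated_mlp_sketch(x, w_gate, w_up, w_down):
    """Assumed gated MLP: down(silu(x @ w_gate) * (x @ w_up))."""
    gate = jax.nn.silu(x @ w_gate)  # [..., intermediate]
    up = x @ w_up                   # [..., intermediate]
    return (gate * up) @ w_down     # back to [..., hidden]


key_g, key_u, key_d = jax.random.split(jax.random.key(0), 3)
w_gate = jax.random.normal(key_g, (4, 8))
w_up = jax.random.normal(key_u, (4, 8))
w_down = jax.random.normal(key_d, (8, 4))
out = gated_mlp_sketch(jnp.ones((2, 3, 4)), w_gate, w_up, w_down)
```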

class easydel.modules.cohere2.modeling_cohere2.Cohere2Model(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Decoder-only Cohere v2 model with embeddings, blocks, and final norm.

get_decoder() -> Module

Returns the decoder part of the model’s graph definition. For Cohere2Model, this is the model itself.

get_embedding() -> Module

Returns the embedding layer of the module.

get_encoder() -> Module

Returns the encoder part of the model’s graph definition. For Cohere2Model (decoder-only), this is not applicable.

get_lm_head() -> Module

Returns the language model head of the module. Cohere2Model does not include the lm_head.