easydel.modules.cohere2.modeling_cohere2_flax#

class easydel.modules.cohere2.modeling_cohere2_flax.Cohere2Attention(*args: Any, **kwargs: Any)[source]#

Bases: AttentionModule

Attention module for Cohere2, with rotary position embeddings (RoPE) and sliding-window attention.

config#

Cohere2 model configuration object.

Type

Cohere2Config

layer_idx#

The index of the current layer.

Type

int

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

JAX precision level.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.modules.cohere2.modeling_cohere2_flax.Cohere2Block(*args: Any, **kwargs: Any)[source]#

Bases: Module

class easydel.modules.cohere2.modeling_cohere2_flax.Cohere2ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.modules.cohere2.modeling_cohere2_flax.Cohere2ForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.modules.cohere2.modeling_cohere2_flax.Cohere2LayerNorm(*args: Any, **kwargs: Any)[source]#

Bases: Module

Cohere Layer Normalization.

dim#

The dimension(s) to normalize over.

Type

Union[int, tuple]

eps#

A small constant used to prevent division by zero.

Type

float

dtype#

The data type for computation.

Type

jnp.dtype

param_dtype#

The data type for the parameters.

Type

jnp.dtype

rngs#

Random number generators.

Type

Optional[nn.Rngs]

class easydel.modules.cohere2.modeling_cohere2_flax.Cohere2MLP(*args: Any, **kwargs: Any)[source]#

Bases: Module

class easydel.modules.cohere2.modeling_cohere2_flax.Cohere2Model(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule