# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import jax
import jax.numpy as jnp
from eformer import common_types
from eformer.escale import apply_logical_sharding
from ejkernel.types import MaskInfo
from flax import nnx as nn
from jax.ad_checkpoint import checkpoint_name
from jaxtyping import Array, Bool, Float, Int
from easydel.infra.base_module import EasyDeLBaseModule
from easydel.infra.factory import TaskType, register_module
from easydel.infra.modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
DecoderLayerOutput,
SequenceClassifierOutput,
)
from easydel.infra.utils import ArrayParam, auto_remat, block_wise_ffn, get_dot_general_by_bits
from easydel.layers.attention_unified import UnifiedAttention
from easydel.layers.base_modules import BaseCausalLMModule, BaseSequenceClassificationModule
from easydel.layers.caching import (
RaggedPagesCache,
RaggedPagesCacheView,
RaggedPagesMetadata,
TransformerCache,
TransformerCacheView,
TransformerMetadata,
)
from easydel.layers.linear import ColumnParallelLinear, RowParallelLinear
from .cohere_configuration import CohereConfig as CohereConfig
def repeat_kv(
    x: Float[Array, "batch seq_len num_kv_heads head_dim"], n_rep: int
) -> Float[Array, "batch seq_len num_heads head_dim"]:
    """Expand key/value heads so that every query head has a matching KV head.

    Each KV head is repeated ``n_rep`` times *consecutively* along the head axis, so
    output head ``h`` is a copy of input head ``h // n_rep``. Query heads
    ``[k * n_rep, (k + 1) * n_rep)`` therefore all share KV head ``k``, which is the
    grouped-query-attention convention.

    Args:
        x: Key or value tensor laid out as ``(batch, seq_len, num_kv_heads, head_dim)``.
        n_rep: Number of query heads that share each KV head.

    Returns:
        Tensor of shape ``(batch, seq_len, num_kv_heads * n_rep, head_dim)``, or ``x``
        itself when ``n_rep == 1``.

    Raises:
        ValueError: If ``x`` is not rank 4 (raised by the shape unpacking).
    """
    batch, seq_len, num_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    # Insert the repeat axis *after* the KV-head axis. Repeating there (rather than
    # before it) yields [k0, k0, ..., k1, k1, ...] instead of the tiled
    # [k0, k1, ..., k0, k1, ...], keeping each query group bound to one KV head.
    x = jnp.repeat(x[:, :, :, jnp.newaxis, :], n_rep, axis=3)
    return x.reshape(batch, seq_len, num_kv_heads * n_rep, head_dim)
class RMSNorm(nn.Module):
    """Bias-free layer normalization used by Cohere models.

    The class keeps its historical name for backward compatibility. It implements the
    computation of Cohere's reference ``CohereLayerNorm``:

    1. Subtract the mean over the last axis.
    2. Divide by ``sqrt(variance + eps)``.
    3. Multiply by a learnable weight. There is no bias term.

    ``dim`` may be an int (hidden-size normalization) or a tuple (per-head Q/K
    normalization). With ``do_t=True`` the stored kernel is transposed before use.
    """

    kernel_init = staticmethod(nn.initializers.ones)

    def __init__(
        self,
        dim: int | tuple,
        eps: float = 1e-6,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        do_t: bool = False,
        rngs: nn.Rngs | None = None,
    ):
        """Create the normalization layer.

        Args:
            dim: Shape of the learnable scale; an int is treated as ``(dim,)``.
            eps: Added to the variance before the reciprocal square root.
            dtype: Dtype of the returned activations.
            param_dtype: Dtype of the stored scale parameter.
            do_t: Transpose the scale before applying it.
            rngs: Random streams; a fixed ``nn.Rngs(0)`` is used when omitted.
        """
        super().__init__()
        if rngs is None:
            # The kernel is ones-initialised, so this key does not affect its values.
            rngs = nn.Rngs(0)
        self.dim = dim
        self.eps = eps
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.do_t = do_t
        self.kernel = ArrayParam.bound(
            shape=(self.dim,) if isinstance(self.dim, int) else self.dim,
            dtype=self.param_dtype,
            init_method="ones",
            key=rngs.params(),
        )

    def _norm(self, x: jnp.ndarray) -> jnp.ndarray:
        """Center ``x`` over the last axis and scale it to unit variance."""
        centered = x - x.mean(-1, keepdims=True)
        variance = jnp.square(centered).mean(-1, keepdims=True)
        return centered * jax.lax.rsqrt(variance + self.eps)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Normalize ``x`` over its last axis and apply the learned scale.

        The statistics and the scaling are computed in at least float32 (float8
        output dtypes are special-cased to plain float32). The result is cast to
        ``self.dtype`` once at the end.
        """
        if self.dtype in [
            jnp.float8_e4m3b11fnuz,
            jnp.float8_e4m3fn,
            jnp.float8_e4m3fnuz,
            jnp.float8_e5m2,
            jnp.float8_e5m2fnuz,
        ]:
            compute_dtype = jnp.float32
        else:
            compute_dtype = jnp.promote_types(self.dtype, jnp.float32)
        x = x.astype(compute_dtype)
        weight = self.kernel.value.astype(compute_dtype)
        if self.do_t:
            weight = weight.T
        # Apply the scale in compute precision before down-casting, so low-precision
        # output dtypes do not lose accuracy in the multiply.
        return (self._norm(x) * weight).astype(self.dtype)
class CohereAttention(UnifiedAttention):
    """Causal multi-head attention for Cohere models, built on ``UnifiedAttention``.

    Cohere-specific behaviour layered on top of the shared implementation:

    * When ``config.use_qk_norm`` is set, ``q_norm`` / ``k_norm`` modules are created
      and applied to queries and keys in ``_postprocess_qkv``.
    * The rotary embedding comes from ``config.get_basic_rope`` and spans the full
      head dimension.
    """

    def __init__(
        self,
        config: CohereConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
        layer_idx: int,
    ) -> None:
        """Build the shared standard/causal attention, then the optional Q/K norms."""
        super().__init__(
            config,
            dtype,
            param_dtype,
            precision,
            rngs=rngs,
            layer_idx=layer_idx,
            attention_type="standard",
            causal=True,
        )
        if not config.use_qk_norm:
            return

        # Both norms differ only in how many heads they cover. Their kernels are
        # stored as (head_dim, num_heads) and transposed at call time (do_t=True).
        norm_kwargs = dict(
            eps=config.layer_norm_eps,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            do_t=True,
        )
        self.q_norm = RMSNorm(dim=(self.head_dim, self.config.num_attention_heads), **norm_kwargs)
        self.k_norm = RMSNorm(dim=(self.head_dim, self.config.num_key_value_heads), **norm_kwargs)

    def _create_rotary(self, config: CohereConfig, dtype: jnp.dtype):
        """Return the rotary embedding module, applied over the whole head dimension."""
        head_dim = self.head_dim
        # NOTE(review): the trailing positional ``True`` is forwarded as-is. If it is an
        # ``is_neox_style`` flag, confirm it against the reference Cohere RoPE (which
        # rotates interleaved pairs) — TODO verify.
        return config.get_basic_rope(dtype, head_dim, head_dim, True)

    def _postprocess_qkv(self, query_states, key_states, value_states):
        """Normalize queries then keys when Q/K norm is enabled; values pass through."""
        if not self.config.use_qk_norm:
            return query_states, key_states, value_states
        return self.q_norm(query_states), self.k_norm(key_states), value_states
class CohereMLP(nn.Module):
    """Gated feed-forward network used inside Cohere decoder blocks.

    Computes ``down_proj(silu(gate_proj(x)) * up_proj(x))`` with bias-free projections.
    The activation is always SiLU. ``layer_idx`` is accepted for interface parity with
    the attention module and is not used.
    """

    def __init__(
        self,
        config: CohereConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
        layer_idx: int,
    ):
        self.config = config
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision

        shared_kwargs = dict(
            dtype=dtype,
            param_dtype=param_dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(config.initializer_range),
            precision=self.precision,
            rngs=rngs,
        )
        make_column = partial(
            ColumnParallelLinear,
            **shared_kwargs,
            **get_dot_general_by_bits(config.bits, config.easy_method),
        )
        make_row = partial(
            RowParallelLinear,
            **shared_kwargs,
            **get_dot_general_by_bits(config.bits, config.easy_method),
        )

        # Creation order (gate, down, up) is kept so parameters draw the same RNG keys.
        hidden, inner = config.hidden_size, config.intermediate_size
        self.gate_proj = make_column(hidden, inner)
        self.down_proj = make_row(inner, hidden)
        self.up_proj = make_column(hidden, inner)

    def __call__(
        self, hidden_states: Float[Array, "batch seq_len hidden_dim"]
    ) -> Float[Array, "batch seq_len hidden_dim"]:
        """Apply the gated MLP, constraining sharding on the input and the output."""
        shard = partial(
            apply_logical_sharding,
            dynamic_axes=common_types.HiddenStateSharding,
            partition_manager=self.config.partition_manager,
        )
        x = shard(hidden_states)
        gated = jax.nn.silu(checkpoint_name(self.gate_proj(x), name="mlp_gate"))
        projected = checkpoint_name(self.up_proj(x), name="mlp_up")
        out = checkpoint_name(self.down_proj(gated * projected), name="mlp_down")
        return shard(out)
class CohereBlock(nn.Module):
    """Single Cohere decoder layer with parallel attention and feed-forward branches.

    One shared ``input_layernorm`` feeds both the self-attention and the MLP. Their
    outputs are added to the un-normalized residual:
    ``x + attn(norm(x)) + mlp(norm(x))``. There is no separate post-attention norm.
    """

    def __init__(
        self,
        config: CohereConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
        layer_idx: int,
    ) -> None:
        super().__init__()
        self.config = config
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.rngs = rngs
        # Wrap both sub-module classes according to the config's gradient-checkpointing
        # policy and target names.
        attn_block, mlp_block = auto_remat(
            CohereAttention,
            CohereMLP,
            policy=config.gradient_checkpointing,
            save_names=config.gradient_checkpointing_targets,
            exclude_names=config.gradient_checkpointing_targets,
        )
        self.self_attn = attn_block(
            config,
            layer_idx=layer_idx,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.mlp = mlp_block(
            config,
            layer_idx=layer_idx,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.input_layernorm = RMSNorm(
            self.config.hidden_size,
            eps=self.config.layer_norm_eps,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )

    def __call__(
        self,
        hidden_states: Float[Array, "batch seq_len hidden_dim"],
        mask_info: MaskInfo | None,
        position_ids: Int[Array, "batch seq_len"],
        mode: common_types.RUNTIME_MODE_TYPES,  # type:ignore
        cache_view: TransformerCacheView | RaggedPagesCacheView | None = None,
        cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None,
        output_attentions: bool = False,
        frequencies: Float[Array, "seq_len head_dim"] | None = None,
    ):
        """Run one decoder layer.

        Args:
            hidden_states: Input activations, ``(batch, seq_len, hidden_dim)``.
            mask_info: Attention mask information forwarded to the attention module.
            position_ids: Token positions, ``(batch, seq_len)``.
            mode: Runtime mode (e.g. train or decode) forwarded to the attention module.
            cache_view: This layer's KV-cache view, if any.
            cache_metadata: Metadata accompanying ``cache_view``.
            output_attentions: Whether the attention module should return its weights.
            frequencies: Precomputed rotary frequencies forwarded to the attention module.

        Returns:
            DecoderLayerOutput containing:
              * the new hidden states;
              * the attention weights as returned by the attention module;
              * the updated cache view.
            ``router_logits`` and ``gate_loss`` are always ``None``.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = apply_logical_sharding(
            hidden_states,
            dynamic_axes=common_types.HiddenStateSharding,
            partition_manager=self.config.partition_manager,
        )
        attn_outputs = self.self_attn(
            hidden_states,
            mask_info,
            position_ids,
            mode,
            cache_view,
            cache_metadata,
            output_attentions,
            frequencies,
        )
        # Parallel layout: the MLP consumes the same normalized input as attention,
        # not the attention output.
        feed_forward_input = hidden_states
        if self.config.use_scan_mlp:
            feed_forward_hidden_states = block_wise_ffn(
                self.mlp,
                feed_forward_input,
                self.config.scan_mlp_chunk_size,
            )
        else:
            feed_forward_hidden_states = self.mlp(feed_forward_input)
        hidden_states = attn_outputs.attention_output + feed_forward_hidden_states + residual
        hidden_states = apply_logical_sharding(
            hidden_states,
            dynamic_axes=common_types.HiddenStateSharding,
            partition_manager=self.config.partition_manager,
        )
        return DecoderLayerOutput(
            hidden_states=hidden_states,
            attention_weight=attn_outputs.attention_weight,
            router_logits=None,
            gate_loss=None,
            cache_view=attn_outputs.cache_view,
        )
@register_module(TaskType.BASE_MODULE, config=CohereConfig, model_type="cohere")
class CohereModel(EasyDeLBaseModule):
    """Decoder-only Cohere transformer assembling embeddings, blocks, and final norm."""

    def __init__(
        self,
        config: CohereConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        embed_block = auto_remat(
            nn.Embed,
            policy=config.gradient_checkpointing,
            save_names=config.gradient_checkpointing_targets,
            exclude_names=config.gradient_checkpointing_targets,
        )
        self.embed_tokens = embed_block(
            config.vocab_size,
            config.hidden_size,
            embedding_init=nn.initializers.normal(stddev=config.initializer_range),
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.layers = [
            CohereBlock(
                config=config,
                layer_idx=i,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )
            for i in range(config.num_hidden_layers)
        ]
        self.norm = RMSNorm(
            self.config.hidden_size,
            eps=self.config.layer_norm_eps,
            dtype=dtype,
            param_dtype=param_dtype,
        )

    def __call__(
        self,
        input_ids: Int[Array, "batch seq_len"] | None = None,
        inputs_embeds: Float[Array, "batch seq_len hidden_dim"] | None = None,
        attention_mask: Bool[Array, "batch seq_len"] | None = None,
        mask_info: MaskInfo | None = None,
        position_ids: Int[Array, "batch seq_len"] | None = None,
        mode: common_types.RUNTIME_MODE_TYPES | None = None,  # type:ignore
        past_key_values: TransformerCache | RaggedPagesCache | None = None,
        cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
    ) -> BaseModelOutput:
        """Forward pass through the core Cohere model.

        Args:
            input_ids: Token ids, ``(batch, seq_len)``. Exactly one of ``input_ids`` and
                ``inputs_embeds`` must be given.
            inputs_embeds: Precomputed embeddings, ``(batch, seq_len, hidden_dim)``.
            attention_mask: Padding mask, ``(batch, seq_len)``. Combined with
                ``mask_info`` through ``MaskInfo.dynamic_init``.
            mask_info: Precomputed mask information.
            position_ids: Token positions. Defaults to ``mask_info.q_position_ids``.
            mode: Runtime mode. When omitted it is ``MODE_DECODE`` if ``seq_len == 1``
                and a cache is supplied, otherwise ``MODE_TRAIN``.
            past_key_values: KV cache. An empty ``TransformerCache`` is created when
                ``None``. Each layer's updated view is written back into it.
            cache_metadata: Metadata for the cache (e.g. paged attention).
            output_attentions: Collect each layer's attention weights.
            output_hidden_states: Collect the input to every layer plus the final
                normalized output.

        Returns:
            BaseModelOutput with the final normalized hidden state, optional per-layer
            hidden states and attentions, and the (updated) cache.

        Raises:
            ValueError: If both or neither of ``input_ids`` / ``inputs_embeds`` are given.
            AssertionError: If ``seq_len`` exceeds ``config.max_position_embeddings``.
        """
        # XOR is true exactly when both or neither input is supplied.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids.astype("i4"))
        sequence_length = inputs_embeds.shape[1]
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        assert sequence_length <= self.config.max_position_embeddings, (
            f"Maximum Position Embedding Reached ! "
            f"(Expected <= {self.config.max_position_embeddings} got {sequence_length})"
        )
        mask_info = MaskInfo.dynamic_init(
            mask_info=mask_info,
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )
        if position_ids is None:
            position_ids = mask_info.q_position_ids
        hidden_states = inputs_embeds
        if mode is None:
            mode = (
                common_types.MODE_DECODE
                if sequence_length == 1 and past_key_values is not None
                else common_types.MODE_TRAIN
            )
        if past_key_values is None:
            past_key_values = TransformerCache.init_empty(len(self.layers))
        hidden_states = apply_logical_sharding(
            hidden_states,
            dynamic_axes=common_types.HiddenStateSharding,
            partition_manager=self.config.partition_manager,
        )
        for idx, block in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            layer_outputs = block(
                hidden_states=hidden_states,
                mask_info=mask_info,
                position_ids=position_ids,
                mode=mode,
                cache_view=past_key_values.views[idx],
                cache_metadata=cache_metadata,
                output_attentions=output_attentions,
                frequencies=self.frequencies,
            )
            hidden_states = layer_outputs.hidden_states
            if output_attentions:
                all_attentions += (layer_outputs.attention_weight,)
            past_key_values[idx] = layer_outputs.cache_view
        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            past_key_values=past_key_values,
        )

    def get_encoder(self) -> nn.Module:
        """Not applicable: CohereModel is decoder-only.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("CohereModel is a decoder-only model and does not have a separate encoder.")

    def get_decoder(self) -> nn.Module:
        """Return the decoder, which for CohereModel is the model itself."""
        return self

    def get_lm_head(self) -> nn.Module:
        """Not applicable: the LM head lives in ``CohereForCausalLM``.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("CohereModel does not include the language model head. See CohereForCausalLM.")

    def get_embedding(self) -> nn.Module:
        """Return the token embedding layer."""
        return self.embed_tokens
@register_module(TaskType.CAUSAL_LM, config=CohereConfig, model_type="cohere")
class CohereForCausalLM(BaseCausalLMModule[CohereModel, CohereConfig]):
    """Cohere model topped with a bias-free language-modeling head and logit scaling."""

    _task_type = TaskType.CAUSAL_LM
    _model_type = "cohere"
    _config_class = CohereConfig

    def __init__(
        self,
        config: CohereConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        super().__init__(
            config=config,
            base_model_class=CohereModel,
            base_model_name="model",
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
            lm_head_bias=False,
        )
        # Multiplier applied to every logit produced by the LM head.
        self.logit_scale = self.config.logit_scale

    def __call__(
        self,
        input_ids: Int[Array, "batch seq_len"] | None = None,
        inputs_embeds: Float[Array, "batch seq_len hidden_dim"] | None = None,
        attention_mask: Bool[Array, "batch seq_len"] | None = None,
        mask_info: MaskInfo | None = None,
        position_ids: Int[Array, "batch seq_len"] | None = None,
        mode: common_types.RUNTIME_MODE_TYPES | None = None,  # type:ignore
        past_key_values: TransformerCache | RaggedPagesCache | None = None,
        cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None,
        apply_lm_head: bool = True,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
    ) -> CausalLMOutput:
        """Run the base model and, optionally, the scaled LM head.

        Args:
            input_ids: Token ids; mutually exclusive with ``inputs_embeds``.
            inputs_embeds: Precomputed input embeddings.
            attention_mask: Padding mask forwarded to the base model.
            mask_info: Precomputed mask information forwarded to the base model.
            position_ids: Token positions forwarded to the base model.
            mode: Runtime mode forwarded to the base model.
            past_key_values: KV cache forwarded to the base model.
            cache_metadata: Cache metadata forwarded to the base model.
            apply_lm_head: When False, ``logits`` is ``None`` and only hidden states
                are returned.
            output_attentions: Request per-layer attention weights.
            output_hidden_states: Request per-layer hidden states.

        Returns:
            CausalLMOutput with ``logits = lm_head(h) * logit_scale`` (or ``None``) and
            the base model's hidden states, attentions and updated cache.
        """
        base_outputs = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            mask_info=mask_info,
            position_ids=position_ids,
            mode=mode,
            past_key_values=past_key_values,
            cache_metadata=cache_metadata,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sharded_hidden = apply_logical_sharding(
            base_outputs.last_hidden_state,
            dynamic_axes=common_types.HiddenStateSharding,
            partition_manager=self.config.partition_manager,
        )
        logits = (self.apply_lm_head(sharded_hidden) * self.logit_scale) if apply_lm_head else None
        return CausalLMOutput(
            logits=logits,
            hidden_states=base_outputs.hidden_states,
            last_hidden_state=base_outputs.last_hidden_state,
            attentions=base_outputs.attentions,
            past_key_values=base_outputs.past_key_values,
        )

    def get_encoder(self) -> nn.Module:
        """Not applicable: the model is decoder-only.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("CohereForCausalLM is a decoder-only model and does not have a separate encoder.")

    def get_decoder(self) -> nn.Module:
        """Return the decoder exposed by the wrapped CohereModel."""
        return self.model.get_decoder()

    def get_lm_head(self) -> nn.Module:
        """Return the language-modeling head."""
        return self.lm_head

    def get_embedding(self) -> nn.Module:
        """Return the token embedding layer of the wrapped CohereModel."""
        return self.model.get_embedding()
@register_module(TaskType.SEQUENCE_CLASSIFICATION, config=CohereConfig, model_type="cohere")
class CohereForSequenceClassification(BaseSequenceClassificationModule[CohereModel, CohereConfig]):
    """Cohere model with a bias-free classification head (``score``) pooled at the last real token."""

    _task_type = TaskType.SEQUENCE_CLASSIFICATION
    _model_type = "cohere"
    _config_class = CohereConfig

    def __init__(
        self,
        config: CohereConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        super().__init__(
            config=config,
            base_model_class=CohereModel,
            base_model_name="model",
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
            classifier_name="score",
            classifier_bias=False,
        )

    def __call__(
        self,
        input_ids: Int[Array, "batch seq_len"] | None = None,
        inputs_embeds: Float[Array, "batch seq_len hidden_dim"] | None = None,
        attention_mask: Bool[Array, "batch seq_len"] | None = None,
        mask_info: MaskInfo | None = None,
        position_ids: Int[Array, "batch seq_len"] | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        mode: common_types.RUNTIME_MODE_TYPES | None = None,  # type:ignore
        past_key_values: TransformerCache | RaggedPagesCache | None = None,
        cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None,
    ) -> SequenceClassifierOutput:
        """Classify each sequence from the logits at its last non-padding token.

        Args:
            input_ids: Token ids; mutually exclusive with ``inputs_embeds``.
            inputs_embeds: Precomputed input embeddings.
            attention_mask: Padding mask forwarded to the base model.
            mask_info: Precomputed mask information forwarded to the base model.
            position_ids: Token positions forwarded to the base model.
            output_attentions: Request per-layer attention weights.
            output_hidden_states: Request per-layer hidden states.
            mode: Runtime mode forwarded to the base model.
            past_key_values: KV cache forwarded to the base model.
            cache_metadata: Cache metadata forwarded to the base model.

        Returns:
            SequenceClassifierOutput with pooled logits ``(batch, num_labels)``. It also
            carries the base model's updated cache, hidden states and attentions.

        Raises:
            ValueError: If ``config.pad_token_id`` is ``None`` and the batch size is not 1.
        """
        transformer_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            mask_info=mask_info,
            position_ids=position_ids,
            mode=mode,
            past_key_values=past_key_values,
            cache_metadata=cache_metadata,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            inputs_embeds=inputs_embeds,
        )
        hidden_states = transformer_outputs.last_hidden_state
        logits = self.score(hidden_states)
        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]
        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None or input_ids is None:
            # Padding cannot be located without a pad id and token ids: pool the last position.
            sequence_lengths = -1
        else:
            # Index of the right-most non-padding token in each row. This is correct for
            # both left and right padding, and is not fooled by a pad id that also
            # appears inside the text. A row consisting only of padding pools index 0.
            non_pad_mask = (input_ids != self.config.pad_token_id).astype("i4")
            token_indices = jnp.arange(input_ids.shape[-1], dtype="i4")
            sequence_lengths = jnp.argmax(token_indices * non_pad_mask, axis=-1)
        pooled_logits = logits[jnp.arange(batch_size), sequence_lengths]
        return SequenceClassifierOutput(
            logits=pooled_logits,
            # Return the cache as updated by the base model, not the caller's input.
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def get_encoder(self) -> nn.Module:
        """Not applicable: the model is decoder-only.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "CohereForSequenceClassification is a decoder-only model and does not have a separate encoder."
        )

    def get_decoder(self) -> nn.Module:
        """Return the wrapped CohereModel, which acts as the decoder."""
        return self.model

    def get_lm_head(self) -> nn.Module:
        """Not applicable: this model uses a classification head (``self.score``).

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "CohereForSequenceClassification uses a classification head (self.score), not an lm_head."
        )

    def get_embedding(self) -> nn.Module:
        """Return the token embedding layer of the wrapped CohereModel."""
        return self.model.get_embedding()