Source code for easydel.modules.gpt_oss.modeling_gpt_oss

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import typing

import chex
import jax
import jax.numpy as jnp
from eformer import common_types
from eformer.escale import apply_logical_sharding
from ejkernel.types import MaskInfo
from flax import nnx as nn
from jax.ad_checkpoint import checkpoint_name
from jaxtyping import Array, Bool, Float, Int

from easydel.infra.base_module import EasyDeLBaseModule
from easydel.infra.factory import TaskType, register_module
from easydel.infra.loss_utils import auxiliary_load_balancing_loss_func
from easydel.infra.modeling_outputs import (
    DecoderLayerOutput,
    MoeCausalLMOutput,
    MoeModelOutput,
    SequenceClassifierOutput,
)
from easydel.infra.utils import ACT2FN, ArrayParam, auto_remat
from easydel.layers.attention_unified import UnifiedAttention
from easydel.layers.base_modules import BaseCausalLMModule
from easydel.layers.caching import (
    RaggedPagesCache,
    RaggedPagesCacheView,
    RaggedPagesMetadata,
    TransformerCache,
    TransformerCacheView,
    TransformerMetadata,
)
from easydel.layers.linear import ColumnParallelLinear
from easydel.layers.moe.linear import ColumnParallelMoELinear, RowParallelMoELinear
from easydel.layers.moe.moe import BaseMoeModule
from easydel.layers.moe.utils import MoeLoadBalancingStrategy, MoeRoutingStrategy
from easydel.layers.norms import RMSNorm

from .gpt_oss_configuration import GptOssConfig


class GptOssRMSNorm(RMSNorm):
    """RMS normalization layer for GPT-OSS; a plain subclass of ``RMSNorm`` that adds no behavior."""
class GptOssExperts(nn.Module):
    """Expert-batched gated feed-forward weights used inside the GPT-OSS MoE layer.

    Owns three per-expert projections with biases (``gate_proj``, ``up_proj``
    and ``down_proj``) and applies a clamped, sigmoid-gated activation between
    the input and output projections.
    """

    # Mapping between fused checkpoint tensors and the separate parameters
    # held by this module. The fused ``gate_up_proj`` tensors interleave the
    # two projections on the last axis: even positions are the gate, odd
    # positions are the up projection. The inverse stacks the pair on a new
    # trailing axis and flattens it to restore that interleaving.
    # NOTE(review): the first argument of the two-tensor inverse is presumably
    # the ``torch`` module itself -- confirm against the conversion caller.
    reform_param: typing.ClassVar = {
        "gate_up_proj$": {
            "splits": [
                {"name": "gate_proj.kernel", "spliter": lambda x: x[..., ::2]},
                {"name": "up_proj.kernel", "spliter": lambda x: x[..., 1::2]},
            ],
            "inverse_spliter": lambda torch, gate, up: torch.stack((gate, up), dim=-1).flatten(-2),
        },
        "gate_up_proj_bias$": {
            "splits": [
                {"name": "gate_proj.bias", "spliter": lambda x: x[..., ::2]},
                {"name": "up_proj.bias", "spliter": lambda x: x[..., 1::2]},
            ],
            "inverse_spliter": lambda torch, gate, up: torch.stack((gate, up), dim=-1).flatten(-2),
        },
        "down_proj$": {
            "splits": [{"name": "down_proj.kernel", "spliter": lambda x: x}],
            "inverse_spliter": lambda x: x,
        },
        "down_proj_bias$": {
            "splits": [{"name": "down_proj.bias", "spliter": lambda x: x}],
            "inverse_spliter": lambda x: x,
        },
    }

    def __init__(
        self,
        config: GptOssConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Create the gate, down and up expert projections.

        Args:
            config: Model configuration; supplies the hidden/intermediate sizes,
                the expert count, the partition manager and the activation name.
            dtype: Computation dtype handed to the projections.
            param_dtype: Parameter dtype handed to the projections.
            precision: Stored on the module; not forwarded to the projections here.
            rngs: Random number generators used for parameter initialization.
        """
        self.config = config
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision

        hidden_dim = config.hidden_size
        inner_dim = config.intermediate_size
        self.intermediate_size = inner_dim
        self.num_experts = config.num_local_experts
        self.hidden_size = hidden_dim
        self.expert_dim = self.intermediate_size

        # Arguments common to all three expert projections.
        common = dict(
            num_experts=config.num_local_experts,
            rngs=rngs,
            use_bias=True,
            partition_manager=config.partition_manager,
            use_expert_tensor_mode=config.use_expert_tensor_mode,
            dtype=dtype,
            param_dtype=param_dtype,
        )

        # Creation order (gate, down, up) is kept so the draws from ``rngs``
        # happen in the same sequence.
        self.gate_proj = ColumnParallelMoELinear(
            in_features=hidden_dim,
            out_features=inner_dim,
            kernel_init=nn.initializers.normal(),
            **common,
        )
        self.down_proj = RowParallelMoELinear(
            in_features=inner_dim,
            out_features=hidden_dim,
            kernel_init=nn.initializers.normal(),
            **common,
        )
        self.up_proj = ColumnParallelMoELinear(
            in_features=hidden_dim,
            out_features=inner_dim,
            kernel_init=nn.initializers.normal(),
            **common,
        )

        # Slope applied to the gate inside the sigmoid.
        self.alpha = 1.702
        self.act_fn = ACT2FN[config.hidden_act]

    def __call__(
        self,
        hidden_states: Float[Array, "batch seq_len hidden_dim"],
        group_sizes: chex.Array,
        sorted_experts: chex.Array | None = None,
    ) -> chex.Array:
        """Run the grouped expert feed-forward pass.

        Args:
            hidden_states: Input activations passed to the gate and up projections.
            group_sizes: Forwarded unchanged to each expert projection.
            sorted_experts: Optional array forwarded unchanged to each expert projection.

        Returns:
            The output of ``down_proj`` applied to ``(up + 1) * gate * sigmoid(alpha * gate)``,
            where ``gate`` is clamped to at most 7.0 and ``up`` to ``[-7.0, 7.0]``.
        """
        routing = (group_sizes, sorted_experts)
        gate = jnp.clip(self.gate_proj(hidden_states, *routing), max=7.0)
        up = jnp.clip(self.up_proj(hidden_states, *routing), min=-7.0, max=7.0)
        gated = gate * jax.nn.sigmoid(gate * self.alpha)
        return self.down_proj((up + 1.0) * gated, *routing)
class GptOssMLP(BaseMoeModule):
    """GPT-OSS mixture-of-experts block: a linear router plus the grouped experts."""

    def __init__(
        self,
        config: GptOssConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Build the router, the experts and the routing hooks.

        Args:
            config: Model configuration (expert counts, hidden size, initializer range).
            dtype: Computation dtype.
            param_dtype: Parameter dtype.
            precision: Precision setting passed to the router and the experts.
            rngs: Random number generators used for parameter initialization.
        """
        super().__init__(
            config=config,
            n_routed_experts=config.num_local_experts,
            num_experts_per_tok=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            lbl_coef=None,
            rzl_coef=None,
            routing_strategy=MoeRoutingStrategy.TOP_K,
            load_balancing_strategy=MoeLoadBalancingStrategy.STANDARD,
        )

        # Router first, experts second: both draw from the same ``rngs``.
        self.router = ColumnParallelLinear(
            config.hidden_size,
            config.num_local_experts,
            use_bias=True,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
            kernel_init=nn.initializers.normal(config.initializer_range),
        )
        self.experts = GptOssExperts(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )

        def _topk_softmax_scatter(logits: jax.Array) -> jax.Array:
            # Softmax over the top-k logits only, written back at the selected
            # expert columns; all other entries stay zero. Indexing by
            # ``logits.shape[0]`` rows means a 2-D ``logits`` is expected.
            selected_vals, selected_idx = jax.lax.top_k(logits, k=self.num_experts_per_tok)
            selected_probs = jax.nn.softmax(selected_vals, axis=-1)
            rows = jnp.arange(logits.shape[0])[:, None]
            return jnp.zeros_like(logits).at[rows, selected_idx].set(selected_probs)

        def _renormalize_weights(weights: jax.Array) -> jax.Array:
            return jax.nn.softmax(weights, axis=-1)

        self.moe_hooks = self.moe_hooks.replace(
            after_gate=_topk_softmax_scatter,
            refine_weights_hook=_renormalize_weights,
        )

    def __call__(self, hidden_states, training=False, layer_idx=None):
        """Route ``hidden_states`` through the experts.

        Args:
            hidden_states: Input activations.
            training: Accepted for interface compatibility; ignored.
            layer_idx: Forwarded to ``moe_call``.

        Returns:
            A ``(output, router_logits)`` tuple, tagged with the checkpoint names
            ``"moe_expert_output"`` and ``"moe_router_logits"`` respectively.
        """
        del training

        experts = self.experts

        def ffn_activation(w0, w1):
            # Same clamped gated activation as ``GptOssExperts.__call__``.
            gate = jnp.clip(w0, max=7.0)
            up = jnp.clip(w1, min=-7.0, max=7.0)
            gated = gate * jax.nn.sigmoid(gate * experts.alpha)
            return (up + 1.0) * gated

        out, router_logits = self.moe_call(
            hidden_state=hidden_states,
            gate_layer=self.router,
            expert_layer=experts,
            wi_kernel=experts.gate_proj.kernel.value,
            wu_kernel=experts.up_proj.kernel.value,
            wd_kernel=experts.down_proj.kernel.value,
            wi_bias=experts.gate_proj.bias.value,
            wu_bias=experts.up_proj.bias.value,
            wd_bias=experts.down_proj.bias.value,
            act_fn=experts.act_fn,
            ffn_activation=ffn_activation,
            layer_idx=layer_idx,
        )
        tagged_out = checkpoint_name(out, "moe_expert_output")
        tagged_logits = checkpoint_name(router_logits, "moe_router_logits")
        return tagged_out, tagged_logits
class GptOssAttention(UnifiedAttention):
    """GPT-OSS attention layer with per-head sink parameters.

    Builds on ``UnifiedAttention``; a sliding window is enabled only for layers
    whose configured type is ``"sliding_attention"``.
    """

    def __init__(
        self,
        config: GptOssConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
        layer_idx: int,
    ):
        """Set up causal attention for one layer and create its sink parameters.

        Args:
            config: Model configuration (layer types, sliding window, head count).
            dtype: Computation dtype.
            param_dtype: Parameter dtype, also used for the sink parameters.
            precision: Precision setting for JAX operations.
            rngs: Random number generators used for parameter initialization.
            layer_idx: Index of this layer; selects its entry in ``config.layer_types``.
        """
        uses_window = config.layer_types[layer_idx] == "sliding_attention"
        window = config.sliding_window if uses_window else None

        super().__init__(
            config,
            dtype,
            param_dtype,
            precision,
            rngs=rngs,
            layer_idx=layer_idx,
            attention_type="standard",
            causal=True,
            sliding_window=window,
        )

        # One learnable sink value per attention head.
        # NOTE(review): how the sinks enter the attention computation is decided
        # by the parent class / kernels and is not visible here.
        self.sinks = ArrayParam.bound(
            shape=(config.num_attention_heads,),
            dtype=param_dtype,
            init_method="normal",
            init_kwargs={"stddev": config.initializer_range},
            key=rngs.param(),
        )
class GptOssDecoderLayer(nn.Module):
    """One GPT-OSS transformer block: pre-norm attention followed by a pre-norm MoE MLP."""

    def __init__(
        self,
        config: GptOssConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
        layer_idx: int,
    ):
        """Create the attention, MoE MLP and the two RMS norms for this layer.

        Args:
            config: Model configuration.
            dtype: Computation dtype.
            param_dtype: Parameter dtype.
            precision: Precision setting for JAX operations.
            rngs: Random number generators used for parameter initialization.
            layer_idx: Index of this layer within the stack.
        """
        self.config = config
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.layer_idx = layer_idx

        # Wrap both sub-module classes according to the checkpointing policy.
        attn_cls, mlp_cls = auto_remat(
            GptOssAttention,
            GptOssMLP,
            policy=config.gradient_checkpointing,
            save_names=config.gradient_checkpointing_targets,
            exclude_names=config.gradient_checkpointing_targets,
        )

        module_kwargs = dict(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.self_attn = attn_cls(layer_idx=layer_idx, **module_kwargs)
        self.mlp = mlp_cls(**module_kwargs)

        norm_kwargs = dict(
            dim=config.hidden_size,
            eps=config.rms_norm_eps,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.input_layernorm = GptOssRMSNorm(**norm_kwargs)
        self.post_attention_layernorm = GptOssRMSNorm(**norm_kwargs)
        self.attention_type = config.layer_types[layer_idx]

    def __call__(
        self,
        hidden_states: Float[Array, "batch seq_len hidden_dim"],
        mask_info: MaskInfo,
        position_ids: Int[Array, "batch seq_len"],
        mode: common_types.RUNTIME_MODE_TYPES,  # type:ignore
        cache_view: TransformerCacheView | RaggedPagesCacheView | None = None,
        cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None,
        output_attentions: bool = False,
        output_router_logits: bool = False,
        frequencies: Float[Array, "seq_len head_dim"] | None = None,
    ):
        """Apply attention and the MoE MLP, each with a residual connection.

        Args:
            hidden_states: Input activations.
            mask_info: Attention mask information passed to the attention layer.
            position_ids: Token position indices.
            mode: Runtime mode passed to the attention layer.
            cache_view: Optional per-layer KV-cache view.
            cache_metadata: Optional cache metadata.
            output_attentions: Whether the attention layer should return its weights.
            output_router_logits: Accepted for interface parity; router logits are
                always returned by this layer.
            frequencies: Optional precomputed rotary frequencies.

        Returns:
            ``DecoderLayerOutput`` holding the new hidden states, the attention
            weights, the updated cache view and the router logits.
        """
        # Attention sub-block (arguments are passed positionally, as before).
        normed_input = self.input_layernorm(hidden_states)
        attn_outputs = self.self_attn(
            normed_input,
            mask_info,
            position_ids,
            mode,
            cache_view,
            cache_metadata,
            output_attentions,
            frequencies,
        )
        hidden_states = hidden_states + attn_outputs.attention_output

        # MoE sub-block.
        moe_out, router_logits = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + moe_out

        hidden_states = apply_logical_sharding(
            hidden_states,
            dynamic_axes=common_types.HiddenStateSharding,
            partition_manager=self.config.partition_manager,
        )
        return DecoderLayerOutput(
            hidden_states=hidden_states,
            attention_weight=attn_outputs.attention_weight,
            cache_view=attn_outputs.cache_view,
            router_logits=router_logits,
        )
@register_module(TaskType.BASE_MODULE, config=GptOssConfig, model_type="gpt_oss")
class GptOssModel(EasyDeLBaseModule):
    """The base GptOss model transformer.

    The core transformer stack: a token embedding, ``config.num_hidden_layers``
    ``GptOssDecoderLayer`` blocks (each with a sparse MoE MLP) and a final RMS
    normalization.

    Attributes:
        embed_tokens: Embedding layer for input tokens.
        layers (list[GptOssDecoderLayer]): The decoder layers, in order.
        norm (GptOssRMSNorm): Final layer normalization.
    """

    def __init__(
        self,
        config: GptOssConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Initializes the GptOssModel.

        Args:
            config (GptOssConfig): The configuration object for the GptOss model.
            dtype (jnp.dtype): Data type for computation. Defaults to jnp.bfloat16.
            param_dtype (jnp.dtype): Data type for parameters. Defaults to jnp.bfloat16.
            precision (jax.lax.PrecisionLike): Precision setting for JAX operations. Defaults to None.
            rngs (nn.Rngs): Random number generators.
        """
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        embed_block = auto_remat(
            nn.Embed,
            policy=config.gradient_checkpointing,
            save_names=config.gradient_checkpointing_targets,
            exclude_names=config.gradient_checkpointing_targets,
        )
        self.embed_tokens = embed_block(
            config.vocab_size,
            config.hidden_size,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.layers = [
            GptOssDecoderLayer(
                config=config,
                layer_idx=layer_idx,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )
            for layer_idx in range(config.num_hidden_layers)
        ]
        self.norm = GptOssRMSNorm(
            dim=config.hidden_size,
            eps=config.rms_norm_eps,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )

    def __call__(
        self,
        input_ids: Int[Array, "batch seq_len"] | None = None,
        inputs_embeds: Float[Array, "batch seq_len hidden_dim"] | None = None,
        attention_mask: Bool[Array, "batch seq_len"] | None = None,
        mask_info: MaskInfo | None = None,
        position_ids: Int[Array, "batch seq_len"] | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        output_router_logits: bool | None = None,
        mode: common_types.RUNTIME_MODE_TYPES | None = None,  # type:ignore
        past_key_values: TransformerCache | RaggedPagesCache | None = None,
        cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None,
    ) -> MoeModelOutput:
        """Forward pass of the GptOssModel.

        Args:
            input_ids: Input token IDs. Shape: (batch_size, sequence_length).
            inputs_embeds: Input embeddings. Exactly one of `input_ids` and
                `inputs_embeds` must be provided.
            attention_mask: Mask to avoid performing attention on padding token indices.
                Shape: (batch_size, sequence_length).
            mask_info: Optional prebuilt mask information; completed via
                `MaskInfo.dynamic_init` from the other inputs.
            position_ids: Position indices for the tokens. Shape: (batch_size, sequence_length).
                Defaults to `mask_info.q_position_ids`.
            output_attentions: Whether to return attention weights.
                Defaults to `config.output_attentions`.
            output_hidden_states: Whether to return hidden states for all layers.
                Defaults to `config.output_hidden_states`.
            output_router_logits: Whether to return router logits from the MoE layers.
                Defaults to `config.output_router_logits`.
            mode: Runtime mode. When None, decode mode is chosen if the sequence length
                is 1 and a cache was supplied, otherwise train mode.
            past_key_values: Precomputed key/value states for attention. When None, an
                empty `TransformerCache` with one view per layer is created.
            cache_metadata: Metadata for paged attention.

        Returns:
            MoeModelOutput: Contains `last_hidden_state`, `hidden_states` (optional),
            `attentions` (optional), `router_logits` (optional) and `past_key_values`.

        Raises:
            ValueError: If both or neither of `input_ids` and `inputs_embeds` are provided.
            AssertionError: If the sequence is longer than `config.max_position_embeddings`.
        """
        # Resolve each output flag against the config exactly once.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_router_logits is None:
            output_router_logits = self.config.output_router_logits
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_logits = () if output_router_logits else None

        # True when both inputs are given or both are missing.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids.astype("i4"))

        sequence_length = inputs_embeds.shape[1]
        assert sequence_length <= self.config.max_position_embeddings, (
            f"Maximum Position Embedding Reached ! "
            f"(Expected <= {self.config.max_position_embeddings} got {sequence_length})"
        )

        mask_info = MaskInfo.dynamic_init(
            mask_info=mask_info,
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )
        if position_ids is None:
            position_ids = mask_info.q_position_ids

        hidden_states = inputs_embeds

        # The mode must be inferred before a missing cache is replaced with an
        # empty one, because the inference depends on whether a cache was passed.
        if mode is None:
            mode = (
                common_types.MODE_DECODE
                if sequence_length == 1 and past_key_values is not None
                else common_types.MODE_TRAIN
            )
        if past_key_values is None:
            past_key_values = TransformerCache.init_empty(len(self.layers))

        hidden_states = apply_logical_sharding(
            hidden_states,
            dynamic_axes=common_types.HiddenStateSharding,
            partition_manager=self.config.partition_manager,
        )

        for idx, block in enumerate(self.layers):
            # Hidden states are recorded on entry to each layer, so the final
            # normalized output is not part of `all_hidden_states`.
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = block(
                hidden_states=hidden_states,
                mask_info=mask_info,
                position_ids=position_ids,
                mode=mode,
                cache_view=past_key_values.views[idx],
                cache_metadata=cache_metadata,
                output_attentions=output_attentions,
                output_router_logits=output_router_logits,
                frequencies=self.frequencies,
            )
            hidden_states = layer_outputs.hidden_states

            if output_attentions:
                all_self_attns += (layer_outputs.attention_weight,)
            if output_router_logits:
                all_router_logits += (layer_outputs.router_logits,)

            past_key_values[idx] = layer_outputs.cache_view

        hidden_states = self.norm(hidden_states)

        return MoeModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            router_logits=all_router_logits,
            past_key_values=past_key_values,
        )

    def get_encoder(self):
        """
        Returns the encoder part of the model's graph definition.
        Decoder-Only models don't have an encoder.
        """
        raise NotImplementedError("This is a decoder-only model and does not have an encoder.")

    def get_decoder(self):
        """
        Returns the decoder part of the model's graph definition.
        """
        return self

    def get_lm_head(self):
        """
        Returns the language model head of the module.
        Base Models don't have a Language Model Head.
        """
        raise NotImplementedError("The base model does not have a language model head.")

    def get_embedding(self):
        """
        Returns the embedding layer of the module.
        """
        return self.embed_tokens
@register_module(TaskType.CAUSAL_LM, config=GptOssConfig, model_type="gpt_oss")
class GptOssForCausalLM(BaseCausalLMModule[GptOssModel, GptOssConfig]):
    """GPT-OSS transformer with a language-modeling head for next-token prediction.

    Wraps ``GptOssModel`` (registered under the name ``"model"``) and adds an
    LM head without bias. The MoE auxiliary load-balancing loss is supported.

    Type Parameters:
        GptOssModel: The base model type
        GptOssConfig: The configuration type
    """

    _task_type = TaskType.CAUSAL_LM
    _model_type = "gpt_oss"
    _config_class = GptOssConfig

    def __init__(
        self,
        config: GptOssConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Construct the causal-LM wrapper around ``GptOssModel``.

        Args:
            config: Model configuration; also supplies ``router_aux_loss_coef``.
            dtype: Data type for computations.
            param_dtype: Data type for parameters.
            precision: Precision setting for JAX operations.
            rngs: Random number generators.
        """
        super().__init__(
            config=config,
            base_model_class=GptOssModel,
            base_model_name="model",
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
            lm_head_bias=False,
            router_aux_loss_coef=config.router_aux_loss_coef,
        )

    def __call__(
        self,
        input_ids: Int[Array, "batch seq_len"] | None = None,
        inputs_embeds: Float[Array, "batch seq_len hidden_dim"] | None = None,
        attention_mask: Bool[Array, "batch seq_len"] | None = None,
        mask_info: MaskInfo | None = None,
        position_ids: Int[Array, "batch seq_len"] | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        output_router_logits: bool | None = None,
        mode: common_types.RUNTIME_MODE_TYPES | None = None,  # type:ignore
        past_key_values: TransformerCache | RaggedPagesCache | None = None,
        cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None,
        apply_lm_head: bool = True,
    ) -> MoeCausalLMOutput:
        """Run the MoE causal-LM forward pass by delegating to ``forward_moe``.

        Args:
            input_ids: Input token IDs.
            inputs_embeds: Input embeddings (alternative to ``input_ids``).
            attention_mask: Mask to avoid attention on padding tokens.
            mask_info: Optional prebuilt mask information.
            position_ids: Position indices for tokens.
            output_attentions: Whether to return attention weights.
            output_hidden_states: Whether to return hidden states.
            output_router_logits: Whether to return router logits.
            mode: Runtime mode.
            past_key_values: Cache containing precomputed key/value states.
            cache_metadata: Metadata for cache handling.
            apply_lm_head: Whether to apply the LM head.

        Returns:
            Whatever ``forward_moe`` returns for these arguments (annotated as
            ``MoeCausalLMOutput``).
        """
        n_experts = self.config.num_local_experts
        experts_per_token = self.config.num_experts_per_tok

        def _balancing_loss(outputs, attention_mask):
            """Load-balancing loss from the router logits, or None when they are absent."""
            gate_logits = outputs.router_logits
            if gate_logits is None:
                return None
            return auxiliary_load_balancing_loss_func(
                gate_logits=gate_logits,
                num_experts=n_experts,
                top_k=experts_per_token,
                attention_mask=attention_mask,
            )

        return self.forward_moe(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            mask_info=mask_info,
            position_ids=position_ids,
            mode=mode,
            past_key_values=past_key_values,
            cache_metadata=cache_metadata,
            apply_lm_head=apply_lm_head,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            aux_loss_fn=_balancing_loss,
        )
@register_module(TaskType.SEQUENCE_CLASSIFICATION, config=GptOssConfig, model_type="gpt_oss")
class GptOssForSequenceClassification(EasyDeLBaseModule):
    """GptOss model with a Sequence Classification head.

    Consists of the base GptOss transformer (`GptOssModel`) followed by a
    bias-free linear layer (`score`) that projects hidden states to
    `config.num_labels` logits. One position per sequence is pooled: the token
    just before the first padding token when `input_ids` and a pad token id are
    available, otherwise the last position. The MoE auxiliary load-balancing
    loss is computed when router logits are requested.

    Attributes:
        model (GptOssModel): The core GptOss transformer model.
        score (ColumnParallelLinear): The linear layer for classification.
    """

    def __init__(
        self,
        config: GptOssConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Initializes the GptOssForSequenceClassification model.

        Args:
            config (GptOssConfig): The configuration object for the GptOss model.
                Must include `num_labels`.
            dtype (jnp.dtype): Data type for computation. Defaults to jnp.bfloat16.
            param_dtype (jnp.dtype): Data type for parameters. Defaults to jnp.bfloat16.
            precision (jax.lax.PrecisionLike): Precision setting for JAX operations. Defaults to None.
            rngs (nn.Rngs): Random number generators.

        Raises:
            AssertionError: If `config.num_labels` is not defined.
        """
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.model = GptOssModel(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        assert hasattr(config, "num_labels"), (
            "in order to use `SequenceClassification` Models in `EasyDeL` "
            "you first need to attach `num_labels` to model `config`"
        )
        self.score = ColumnParallelLinear(
            self.config.hidden_size,
            config.num_labels,
            dtype=dtype,
            param_dtype=param_dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(stddev=config.initializer_range),
            precision=self.precision,
            rngs=rngs,
        )

    def __call__(
        self,
        input_ids: Int[Array, "batch seq_len"] | None = None,
        inputs_embeds: Float[Array, "batch seq_len hidden_dim"] | None = None,
        attention_mask: Bool[Array, "batch seq_len"] | None = None,
        mask_info: MaskInfo | None = None,
        position_ids: Int[Array, "batch seq_len"] | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        output_router_logits: bool | None = None,
        mode: common_types.RUNTIME_MODE_TYPES | None = None,  # type:ignore
        past_key_values: TransformerCache | RaggedPagesCache | None = None,
        cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None,
    ) -> SequenceClassifierOutput:
        """Forward pass of the GptOssForSequenceClassification model.

        Args:
            input_ids: Input token IDs. Shape: (batch_size, sequence_length).
            inputs_embeds: Input embeddings. Either `input_ids` or `inputs_embeds`
                must be provided.
            attention_mask: Mask to avoid performing attention on padding token indices.
                Shape: (batch_size, sequence_length).
            mask_info: Optional prebuilt mask information, forwarded to the base model.
            position_ids: Position indices for the tokens. Shape: (batch_size, sequence_length).
            output_attentions: Whether to return attention weights.
            output_hidden_states: Whether to return hidden states for all layers.
            output_router_logits: Whether to return router logits from the MoE layers.
                The auxiliary loss is only computed when this is truthy.
            mode: Runtime mode, forwarded to the base model.
            past_key_values: Precomputed key/value states for attention.
            cache_metadata: Metadata for paged attention.

        Returns:
            SequenceClassifierOutput: Contains the pooled `logits`, the
            `past_key_values` that were passed in, `hidden_states` (optional),
            `attentions` (optional) and `aux_loss` (optional).

        Raises:
            ValueError: If `config.pad_token_id` is None and `batch_size != 1`.
        """
        transformer_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            mask_info=mask_info,
            position_ids=position_ids,
            mode=mode,
            past_key_values=past_key_values,
            cache_metadata=cache_metadata,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            inputs_embeds=inputs_embeds,
        )

        hidden_states = transformer_outputs.last_hidden_state
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        pad_token_id = self.config.pad_token_id
        if pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

        if pad_token_id is not None and input_ids is not None:
            # Index of the first pad token minus one, i.e. the last real token.
            # With no pad token in a row, argmax is 0, so the modulo wraps -1
            # around to the final position.
            sequence_lengths = jnp.argmax(jnp.equal(input_ids, pad_token_id).astype("i4"), -1) - 1
            sequence_lengths = sequence_lengths % input_ids.shape[-1]
        else:
            # Without token ids or a pad id, fall back to the last position.
            sequence_lengths = -1

        pooled_logits = logits[jnp.arange(batch_size), sequence_lengths]

        aux_loss = None
        if output_router_logits and transformer_outputs.router_logits is not None:
            aux_loss = auxiliary_load_balancing_loss_func(
                gate_logits=transformer_outputs.router_logits,
                num_experts=self.config.num_local_experts,
                top_k=self.config.num_experts_per_tok,
                attention_mask=attention_mask,
            )
            # Scale the balancing loss by its coefficient. Adding the scaled
            # term onto the raw loss would produce aux_loss * (1 + coef).
            aux_loss = aux_loss * self.config.router_aux_loss_coef

        return SequenceClassifierOutput(
            logits=pooled_logits,
            past_key_values=past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            aux_loss=aux_loss,
        )

    def get_encoder(self):
        """
        Returns the encoder part of the model's graph definition.
        Decoder-Only models don't have an encoder.
        """
        raise NotImplementedError("This is a decoder-only model and does not have an encoder.")

    def get_decoder(self):
        """
        Returns the decoder part of the model's graph definition.
        """
        return self.model.get_decoder()

    def get_lm_head(self):
        """
        Returns the language model head of the module.
        This model has a sequence classification head, not an LM Head.
        """
        raise NotImplementedError("This model has a sequence classification head, not a language model head.")

    def get_embedding(self):
        """
        Returns the embedding layer of the module.
        """
        return self.model.get_embedding()