easydel.modules.gpt_oss.modeling_gpt_oss

class easydel.modules.gpt_oss.modeling_gpt_oss.GptOssAttention(*args: Any, **kwargs: Any)

Bases: UnifiedAttention

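GPT-OSS attention with attention-sink support.

Extends UnifiedAttention with layer-specific sliding windows and learned sink tokens. A sink lets each head assign part of its attention probability to no real token, instead of forcing all of it onto the visible keys.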
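The sink mechanism is easiest to see on a single head. This NumPy sketch assumes a sink is one learned logit per head that is appended to each row of attention scores before the softmax and then discarded, so it absorbs probability mass without contributing a value vector. `sink_attention`, `sink_logit`, and `window` are illustrative names, not EasyDeL parameters, and batching, grouped-query heads, and KV caching from UnifiedAttention are not reproduced.

```python
import numpy as np


def sink_attention(q, k, v, sink_logit, window=None):
    """Single-head causal attention with one learned sink logit (sketch).

    q, k, v: arrays of shape (seq_len, head_dim).
    sink_logit: scalar standing in for a per-head learned sink parameter.
    window: if given, query i only attends to keys j with i - window < j <= i.
    Returns (output, attention_weights_over_real_keys).
    """
    seq_len, head_dim = q.shape
    scores = q @ k.T / np.sqrt(head_dim)              # (seq_len, seq_len)
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    allowed = j <= i                                  # causal mask
    if window is not None:
        allowed &= (i - j) < window                   # sliding-window mask
    scores = np.where(allowed, scores, -np.inf)

    # Append the sink as an extra "key" column and softmax over keys + sink.
    sink = np.full((seq_len, 1), float(sink_logit))
    logits = np.concatenate([scores, sink], axis=-1)
    logits = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(logits)
    probs = probs / probs.sum(axis=-1, keepdims=True)

    # Drop the sink column: its mass is absorbed and contributes no value.
    weights = probs[:, :-1]
    return weights @ v, weights
```

With a finite sink logit each row of weights sums to less than 1; pushing the sink logit toward negative infinity recovers ordinary softmax attention.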

class easydel.modules.gpt_oss.modeling_gpt_oss.GptOssDecoderLayer(*args: Any, **kwargs: Any)

Bases: Module

Single GPT-OSS decoder block with attention and expert MLP.
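As a rough picture of how one block composes its parts, the sketch below assumes the pre-norm residual layout of the public GPT-OSS reference: an RMSNorm before attention and another before the MoE MLP, each sub-layer wrapped in a residual connection. `attn_fn` and `mlp_fn` are hypothetical stand-ins for GptOssAttention and GptOssMLP, and the eps value is an assumption.

```python
import numpy as np


def rms_norm(x, weight, eps=1e-5):
    # Normalize by the root-mean-square over the hidden axis, then scale.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * weight


def decoder_layer(x, attn_fn, mlp_fn, ln1_w, ln2_w):
    """Pre-norm residual decoder block (sketch).

    x: (seq_len, hidden). attn_fn / mlp_fn: callables (seq_len, hidden) -> same shape.
    """
    h = x + attn_fn(rms_norm(x, ln1_w))       # attention sub-layer + residual
    return h + mlp_fn(rms_norm(h, ln2_w))     # MoE MLP sub-layer + residual
```

Because both sub-layers add to a residual stream, a block whose attention and MLP output zeros passes its input through unchanged.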

class easydel.modules.gpt_oss.modeling_gpt_oss.GptOssExperts(*args: Any, **kwargs: Any)

Bases: Module

Grouped expert feed-forward network used inside GPT-OSS MoE layers.

reform_param: ClassVar

Regex-keyed mapping between fused parameter names and the per-projection parameters of this module. Each entry lists its target parameters under splits (each with its own spliter function) and an inverse_spliter that fuses them back:

down_proj$ → down_proj.kernel
down_proj_bias$ → down_proj.bias
gate_up_proj$ → gate_proj.kernel, up_proj.kernel
gate_up_proj_bias$ → gate_proj.bias, up_proj.bias

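The gate_up_proj$ entries split one fused projection into gate_proj and up_proj. The spliter lambdas are not shown in this listing, so the sketch assumes the interleaved layout used by the Hugging Face GPT-OSS reference (even output columns are gate, odd columns are up). `split_gate_up` and `fuse_gate_up` are hypothetical helpers illustrating a spliter/inverse_spliter pair, not EasyDeL functions.

```python
import numpy as np


def split_gate_up(gate_up):
    """Split a fused (..., 2 * d_ff) array into gate and up (sketch).

    Assumes an interleaved layout: even columns -> gate, odd columns -> up.
    """
    return gate_up[..., ::2], gate_up[..., 1::2]


def fuse_gate_up(gate, up):
    """Inverse of split_gate_up: re-interleave gate and up columns."""
    fused = np.empty(gate.shape[:-1] + (2 * gate.shape[-1],), dtype=gate.dtype)
    fused[..., ::2] = gate
    fused[..., 1::2] = up
    return fused
```

A spliter and its inverse should round-trip exactly; if a checkpoint instead stores gate and up as two contiguous halves, the slicing would be `[..., :d_ff]` and `[..., d_ff:]`.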
class easydel.modules.gpt_oss.modeling_gpt_oss.GptOssForCausalLM(*args: Any, **kwargs: Any)

Bases: BaseCausalLMModule[GptOssModel, GptOssConfig]

GPT-OSS model with a Causal Language Modeling head.

This model consists of the base GPT-OSS transformer (GptOssModel) followed by a language modeling head for next-token prediction. Because its decoder layers use mixture-of-experts MLPs, it also supports the MoE auxiliary loss.

Type Parameters:

GptOssModel: The base model type.
GptOssConfig: The configuration type.
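The exact auxiliary-loss formula is not given in this listing. A common choice for MoE models is the Switch Transformer load-balancing loss, sketched here under that assumption: for each expert it multiplies the fraction of routing slots the expert receives by its mean router probability. Uniform routing scores 1.0, and a router that sends every token to one expert with top_k=1 approaches the number of experts. Scaling conventions (for example, an extra factor of top_k, or the coefficient applied before adding it to the LM loss) vary between implementations; `load_balancing_loss` is a hypothetical name.

```python
import numpy as np


def load_balancing_loss(router_logits, top_k):
    """Switch-Transformer-style load-balancing loss (sketch).

    router_logits: (num_tokens, num_experts) raw router scores.
    Returns a scalar that is 1.0 when routing is uniform.
    """
    num_tokens, num_experts = router_logits.shape
    z = router_logits - router_logits.max(axis=-1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)

    # Which experts each token is dispatched to (top-k by logit).
    chosen = np.argsort(-router_logits, axis=-1)[:, :top_k]
    counts = np.bincount(chosen.ravel(), minlength=num_experts)
    frac_slots = counts / (num_tokens * top_k)        # sums to 1 over experts

    mean_prob = probs.mean(axis=0)                    # sums to 1 over experts
    return float(num_experts * np.sum(frac_slots * mean_prob))
```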

class easydel.modules.gpt_oss.modeling_gpt_oss.GptOssForSequenceClassification(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

GptOss model with a Sequence Classification head.

This model consists of the base GptOss transformer (GptOssModel) followed by a linear layer (score) that projects the transformer’s output hidden states (typically the hidden state of the first token) to the number of classes. It also computes the auxiliary loss from the MoE layers.
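Pooling and projection take only a few lines of NumPy. The description mentions the first token; many decoder-only classifiers instead pool the last non-padding token, because under causal attention only that position has seen the whole sequence. The sketch therefore takes the pool position as an argument. `classify`, `last_non_pad_index`, and the right-padding assumption are illustrative, not EasyDeL's implementation.

```python
import numpy as np


def classify(hidden_states, score_w, pool_index):
    """Project one pooled hidden state per sequence to class logits (sketch).

    hidden_states: (batch, seq_len, hidden); score_w: (hidden, num_labels);
    pool_index: (batch,) position to pool from each sequence.
    """
    batch = hidden_states.shape[0]
    pooled = hidden_states[np.arange(batch), pool_index]   # (batch, hidden)
    return pooled @ score_w                                 # (batch, num_labels)


def last_non_pad_index(input_ids, pad_token_id):
    """Index of the last non-padding token per row, assuming right padding."""
    return (input_ids != pad_token_id).sum(axis=-1) - 1
```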

config (GptOssConfig): Configuration object for the model.

dtype (jnp.dtype): Data type for computation.

param_dtype (jnp.dtype): Data type for parameters.

precision (jax.lax.PrecisionLike): Precision setting for JAX operations.

rngs (nn.Rngs): Random number generators.

model (GptOssModel): The core GptOss transformer model.

score (ParallelLinear): The linear layer for classification.

num_experts (int): Total number of experts.

num_experts_per_tok (int): Number of experts each token is routed to.

get_decoder()

Returns the decoder part of the model’s graph definition.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()

Returns the language model head of the module. This model has a sequence classification head, not an LM head.

class easydel.modules.gpt_oss.modeling_gpt_oss.GptOssMLP(*args: Any, **kwargs: Any)

Bases: BaseMoeModule

Mixture-of-experts MLP combining the router with the grouped experts (GptOssExperts).
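Routing can be sketched as top-k selection over the router logits followed by a softmax over only the selected logits. This order (select, then normalize) is an assumption based on the public GPT-OSS reference; other MoE designs apply the softmax over all experts first. `router_b` models a router bias (pass zeros if there is none), and `experts` is a list of hypothetical per-expert callables standing in for GptOssExperts.

```python
import numpy as np


def moe_forward(x, router_w, router_b, experts, top_k):
    """Top-k mixture-of-experts forward pass (sketch).

    x: (num_tokens, hidden); router_w: (hidden, num_experts); router_b: (num_experts,)
    experts: list of callables mapping (n, hidden) -> (n, hidden).
    """
    logits = x @ router_w + router_b                       # (T, E)
    top_idx = np.argsort(-logits, axis=-1)[:, :top_k]      # (T, k) chosen experts
    top_logits = np.take_along_axis(logits, top_idx, axis=-1)

    # Softmax over the selected experts only.
    w = np.exp(top_logits - top_logits.max(axis=-1, keepdims=True))
    w = w / w.sum(axis=-1, keepdims=True)

    out = np.zeros_like(x)
    for e, expert in enumerate(experts):
        tok, slot = np.nonzero(top_idx == e)               # tokens routed to expert e
        if tok.size:
            # Each token picks an expert at most once, so tok has no duplicates.
            out[tok] += w[tok, slot][:, None] * expert(x[tok])
    return out
```

Running each expert only on the tokens routed to it is what keeps the layer sparse: per token, only top_k of the experts do any work.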

class easydel.modules.gpt_oss.modeling_gpt_oss.GptOssModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

The base GptOss transformer model.

This class implements the core GptOss transformer: an embedding layer, a stack of GptOssDecoderLayer blocks (each with a sparse MoE MLP), and a final RMS normalization.
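The data flow through the base model reduces to an embedding lookup, the decoder stack, and the final norm. This sketch omits attention masks, KV caching, and gradient checkpointing; `gpt_oss_model_forward` and its arguments are hypothetical.

```python
import numpy as np


def gpt_oss_model_forward(input_ids, embedding, layers, final_norm):
    """Embedding -> decoder layers -> final norm (sketch).

    input_ids: (seq_len,) int array; embedding: (vocab_size, hidden);
    layers: callables (seq_len, hidden) -> (seq_len, hidden);
    final_norm: callable applied to the last hidden states.
    """
    h = embedding[input_ids]          # token embedding lookup
    for layer in layers:              # stand-ins for GptOssDecoderLayer
        h = layer(h)
    return final_norm(h)              # stand-in for the final RMSNorm
```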

config (GptOssConfig): Configuration object for the model.

dtype (jnp.dtype): Data type for computation.

param_dtype (jnp.dtype): Data type for parameters.

precision (jax.lax.PrecisionLike): Precision setting for JAX operations.

rngs (nn.Rngs): Random number generators.

embed_tokens (nn.Embed): Embedding layer for input tokens.

layers (tp.List[GptOssDecoderLayer]): List of decoder layers.

norm (RMSNorm): Final layer normalization.

gradient_checkpointing (EasyDeLGradientCheckPointers): Gradient checkpointing configuration.

get_decoder()

Returns the decoder part of the model’s graph definition.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()

Returns the language model head of the module. Base models don’t have a language model head.

class easydel.modules.gpt_oss.modeling_gpt_oss.GptOssRMSNorm(*args: Any, **kwargs: Any)

Bases: RMSNorm
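For reference, a NumPy RMSNorm sketch. It assumes eps = 1e-5 and float32 internal computation with a cast back to the input dtype, which is common practice but not confirmed by this listing; GptOssRMSNorm may differ from the base RMSNorm in exactly these details.

```python
import numpy as np


def rms_norm(x, weight, eps=1e-5):
    """RMSNorm (sketch): divide by the root-mean-square of the last axis, then scale."""
    x32 = x.astype(np.float32)                        # compute in float32
    inv_rms = 1.0 / np.sqrt(np.mean(x32 ** 2, axis=-1, keepdims=True) + eps)
    return (weight * (x32 * inv_rms)).astype(x.dtype) # cast back to input dtype
```

Unlike LayerNorm, RMSNorm neither subtracts the mean nor adds a bias, so with a unit weight each output row has a root-mean-square of roughly 1.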