easydel.modules.gpt_oss.modeling_gpt_oss#
- class easydel.modules.gpt_oss.modeling_gpt_oss.GptOssAttention(*args: Any, **kwargs: Any)[source]#
Bases: UnifiedAttention
GPT-OSS attention with sink-token support.
Inherits from UnifiedAttention. Supports layer-specific sliding windows and sink tokens.
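The two features named above can be illustrated with a minimal single-head sketch. It assumes that a sink is a single extra logit appended to each row of attention scores before the softmax and dropped afterwards, and that the sliding window keeps only the most recent `window` positions. The function name `sink_attention` and all shapes are hypothetical; this is not the EasyDeL implementation.

```python
import jax
import jax.numpy as jnp


def sink_attention(q, k, v, sink_logit, window=None):
    """Single-head causal attention with one extra "sink" logit per query row.

    q, k, v: [seq, dim]; sink_logit: scalar; window: optional sliding-window size.
    Hypothetical helper for illustration only.
    """
    seq, dim = q.shape
    scores = (q @ k.T) / jnp.sqrt(dim)
    pos = jnp.arange(seq)
    # Causal mask: a query may only look at itself and earlier positions.
    mask = pos[None, :] <= pos[:, None]
    if window is not None:
        # Sliding window: keep only the most recent `window` positions.
        mask = mask & (pos[:, None] - pos[None, :] < window)
    scores = jnp.where(mask, scores, -jnp.inf)
    # Append the sink logit as an extra column, normalize, then drop it again.
    sink = jnp.full((seq, 1), sink_logit, dtype=scores.dtype)
    probs = jax.nn.softmax(jnp.concatenate([scores, sink], axis=-1), axis=-1)
    probs = probs[:, :-1]
    return probs @ v, probs


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (6, 8))
k = jax.random.normal(kk, (6, 8))
v = jax.random.normal(kv, (6, 8))
out, probs = sink_attention(q, k, v, sink_logit=0.0, window=3)
```

Because the sink column absorbs part of the probability mass, each row of `probs` sums to less than one, which lets a query attend to "nothing" when no key is relevant.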
- class easydel.modules.gpt_oss.modeling_gpt_oss.GptOssDecoderLayer(*args: Any, **kwargs: Any)[source]#
Bases: Module
Single GPT-OSS decoder block with attention and expert MLP.
- class easydel.modules.gpt_oss.modeling_gpt_oss.GptOssExperts(*args: Any, **kwargs: Any)[source]#
Bases: Module
Grouped expert feed-forward network used inside GPT-OSS MoE layers.
- reform_param: ClassVar = {'down_proj$': {'inverse_spliter': <function GptOssExperts.<lambda>>, 'splits': [{'name': 'down_proj.kernel', 'spliter': <function GptOssExperts.<lambda>>}]}, 'down_proj_bias$': {'inverse_spliter': <function GptOssExperts.<lambda>>, 'splits': [{'name': 'down_proj.bias', 'spliter': <function GptOssExperts.<lambda>>}]}, 'gate_up_proj$': {'inverse_spliter': <function GptOssExperts.<lambda>>, 'splits': [{'name': 'gate_proj.kernel', 'spliter': <function GptOssExperts.<lambda>>}, {'name': 'up_proj.kernel', 'spliter': <function GptOssExperts.<lambda>>}]}, 'gate_up_proj_bias$': {'inverse_spliter': <function GptOssExperts.<lambda>>, 'splits': [{'name': 'gate_proj.bias', 'spliter': <function GptOssExperts.<lambda>>}, {'name': 'up_proj.bias', 'spliter': <function GptOssExperts.<lambda>>}]}}#
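The `reform_param` mapping pairs each fused checkpoint tensor (for example `gate_up_proj`) with a splitter per target parameter (`gate_proj.kernel`, `up_proj.kernel`) and an inverse that rebuilds the fused tensor. The lambdas themselves are not shown in this reference, so the following is only a sketch of what such a pair could look like, assuming the gate and up columns are interleaved along the last axis. The helper names and the layout are assumptions.

```python
import jax.numpy as jnp


def split_gate_up(gate_up):
    """Split a fused [experts, hidden, 2 * inter] tensor into gate and up parts.

    Assumes an interleaved layout: even columns are gate, odd columns are up.
    """
    return gate_up[..., 0::2], gate_up[..., 1::2]


def merge_gate_up(gate, up):
    """Inverse of split_gate_up: interleave gate and up along the last axis."""
    stacked = jnp.stack([gate, up], axis=-1)  # [..., inter, 2]
    return stacked.reshape(*gate.shape[:-1], gate.shape[-1] * 2)


fused = jnp.arange(2 * 3 * 8, dtype=jnp.float32).reshape(2, 3, 8)
gate, up = split_gate_up(fused)
restored = merge_gate_up(gate, up)
```

A round trip through the splitter and its inverse should reproduce the fused tensor exactly, which is the property a conversion in both directions relies on.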
- class easydel.modules.gpt_oss.modeling_gpt_oss.GptOssForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases: BaseCausalLMModule[GptOssModel, GptOssConfig]
GPT-OSS model with a Causal Language Modeling head.
This model consists of the base GPT-OSS transformer (GptOssModel) followed by a language modeling head for next token prediction. Supports MoE with auxiliary loss.
- Type Parameters:
GptOssModel: The base model type.
GptOssConfig: The configuration type.
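As a rough illustration of next-token prediction combined with an MoE auxiliary loss, the sketch below computes a shifted cross-entropy and adds a weighted auxiliary term. The function names, the `aux_coef` parameter, and the way the auxiliary loss is passed in are assumptions for this example and are not taken from the EasyDeL API.

```python
import jax
import jax.numpy as jnp


def next_token_loss(logits, input_ids):
    """Mean cross-entropy of predicting token t+1 from position t.

    logits: [batch, seq, vocab]; input_ids: [batch, seq].
    """
    logp = jax.nn.log_softmax(logits[:, :-1], axis=-1)
    targets = input_ids[:, 1:]
    picked = jnp.take_along_axis(logp, targets[..., None], axis=-1)[..., 0]
    return -picked.mean()


def total_loss(logits, input_ids, aux_loss, aux_coef=0.01):
    """Language-modeling loss plus a weighted router auxiliary loss (hypothetical)."""
    return next_token_loss(logits, input_ids) + aux_coef * aux_loss


logits = jnp.zeros((2, 4, 5))  # uniform predictions over a vocabulary of 5
input_ids = jnp.array([[1, 2, 3, 4], [0, 1, 2, 3]])
lm = next_token_loss(logits, input_ids)
total = total_loss(logits, input_ids, aux_loss=jnp.float32(2.0), aux_coef=0.1)
```

With uniform logits the language-modeling term equals the log of the vocabulary size, which makes the example easy to check by hand.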
- class easydel.modules.gpt_oss.modeling_gpt_oss.GptOssForSequenceClassification(*args: Any, **kwargs: Any)[source]#
Bases: EasyDeLBaseModule
GptOss model with a Sequence Classification head.
This model consists of the base GptOss transformer (GptOssModel) followed by a linear layer (score) that projects the transformer’s output hidden states (typically the hidden state of the first token) to the number of classes for classification. It also handles the calculation of the auxiliary loss from the MoE layers.
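The pooling-then-projection step can be sketched as follows. Which token position is pooled is left as an explicit argument here, and `classify`, `score_kernel`, and `pool_index` are hypothetical names for this illustration, not EasyDeL identifiers.

```python
import jax.numpy as jnp


def classify(hidden_states, score_kernel, pool_index):
    """Pick one hidden state per sequence and project it to class logits.

    hidden_states: [batch, seq, hidden]; score_kernel: [hidden, num_labels];
    pool_index: [batch] integer position to pool for each sequence.
    """
    batch = hidden_states.shape[0]
    pooled = hidden_states[jnp.arange(batch), pool_index]  # [batch, hidden]
    return pooled @ score_kernel  # [batch, num_labels]


hidden = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape(2, 3, 4)
score_kernel = jnp.ones((4, 2))
logits = classify(hidden, score_kernel, jnp.array([0, 2]))
```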
- config#
Configuration object for the model.
- Type
GptOssConfig
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- model#
The core GptOss transformer model.
- Type
GptOssModel
- score#
The linear layer for classification.
- Type
- num_experts#
Total number of experts.
- Type
int
- num_experts_per_tok#
Number of experts to route per token.
- Type
int
- class easydel.modules.gpt_oss.modeling_gpt_oss.GptOssMLP(*args: Any, **kwargs: Any)[source]#
Bases: BaseMoeModule
Mixture-of-experts MLP combining the router and shared experts.
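A minimal sketch of how a router and a set of experts can be combined is shown below. It assumes top-k routing with a softmax over the selected logits only, uses plain linear experts, and evaluates every expert densely for clarity. The names `route_top_k` and `moe_forward` are hypothetical, shared experts are not modeled, and a real implementation would dispatch tokens to experts rather than compute all outputs.

```python
import jax
import jax.numpy as jnp


def route_top_k(router_logits, k):
    """Select k experts per token and normalize their weights with a softmax."""
    top_vals, top_idx = jax.lax.top_k(router_logits, k)
    return jax.nn.softmax(top_vals, axis=-1), top_idx


def moe_forward(x, router_kernel, expert_kernels, k):
    """Dense reference MoE layer.

    x: [tokens, hidden]; router_kernel: [hidden, num_experts];
    expert_kernels: [num_experts, hidden, hidden].
    """
    weights, idx = route_top_k(x @ router_kernel, k)
    # Run every expert on every token, then gather the k selected outputs.
    all_out = jnp.einsum("th,ehd->ted", x, expert_kernels)
    rows = jnp.arange(x.shape[0])[:, None]
    chosen = all_out[rows, idx]  # [tokens, k, hidden]
    return (weights[..., None] * chosen).sum(axis=1), weights, idx


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (5, 8))
router_kernel = jax.random.normal(k2, (8, 4))       # 4 experts in total
expert_kernels = jax.random.normal(k3, (4, 8, 8))
out, weights, idx = moe_forward(x, router_kernel, expert_kernels, k=2)
```

In this sketch the second dimension of `router_kernel` plays the role of the total expert count and `k` plays the role of the number of experts routed per token.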
- class easydel.modules.gpt_oss.modeling_gpt_oss.GptOssModel(*args: Any, **kwargs: Any)[source]#
Bases: EasyDeLBaseModule
The base GptOss model transformer.
This class represents the core transformer architecture of the GptOss model, consisting of an embedding layer, multiple GptOssDecoderLayer layers (with sparse MoE), and a final layer normalization.
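The data flow described above (embedding lookup, a stack of decoder layers, then a final normalization) can be sketched in a few lines. The layer body is a toy residual block standing in for GptOssDecoderLayer, RMS normalization is assumed for the final norm, and the `remat` flag only illustrates how gradient checkpointing could be applied per layer with `jax.checkpoint`. All names are hypothetical.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, scale, eps=1e-5):
    """Root-mean-square normalization over the last axis (assumed final norm)."""
    var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(var + eps) * scale


def forward(input_ids, embed, layer_weights, norm_scale, remat=False):
    """Embed tokens, apply each layer in order, then normalize."""

    def layer(h, w):
        # Toy residual block standing in for a full decoder layer.
        return h + jnp.tanh(h @ w)

    # Optionally recompute layer activations in the backward pass.
    layer_fn = jax.checkpoint(layer) if remat else layer
    h = embed[input_ids]  # [batch, seq, hidden]
    for w in layer_weights:
        h = layer_fn(h, w)
    return rms_norm(h, norm_scale)


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
embed = jax.random.normal(k1, (16, 8))            # vocabulary of 16, hidden size 8
layer_weights = list(jax.random.normal(k2, (3, 8, 8)) * 0.1)
input_ids = jnp.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
plain = forward(input_ids, embed, layer_weights, jnp.ones(8))
rematted = forward(input_ids, embed, layer_weights, jnp.ones(8), remat=True)
```

Checkpointing changes only when activations are stored or recomputed, so the forward outputs with and without `remat` should agree.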
- config#
Configuration object for the model.
- Type
GptOssConfig
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- embed_tokens#
Embedding layer for input tokens.
- Type
nn.Embed
- layers#
List of decoder layers.
- Type
tp.List[GptOssDecoderLayer]
- gradient_checkpointing#
Gradient checkpointing configuration.