easydel.modules.gpt_neox.modeling_gpt_neox_flax
- class easydel.modules.gpt_neox.modeling_gpt_neox_flax.GPTNeoXAttention(*args: Any, **kwargs: Any)
Bases: AttentionModule
GPT-NeoX Attention module.
This module implements the attention mechanism used in the GPT-NeoX model, including rotary position embeddings and parallel linear layers for the query, key, and value (QKV) projections.
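Rotary position embeddings can be illustrated with a short JAX sketch. It assumes the "rotate-half" formulation used by GPT-NeoX-style models, in which only the first `rotary_dim` channels of each head are rotated and the rest pass through unchanged. The function names `rope_frequencies` and `apply_rotary` and the `rotary_dim` parameter are hypothetical illustrations, not EasyDeL's API.

```python
import jax
import jax.numpy as jnp


def rope_frequencies(seq_len: int, rotary_dim: int, base: float = 10000.0):
    """Return (cos, sin) tables of shape (seq_len, rotary_dim)."""
    # One inverse frequency per pair of rotated channels.
    inv_freq = 1.0 / (base ** (jnp.arange(0, rotary_dim, 2) / rotary_dim))
    positions = jnp.arange(seq_len)
    angles = jnp.outer(positions, inv_freq)               # (seq_len, rotary_dim // 2)
    angles = jnp.concatenate([angles, angles], axis=-1)   # (seq_len, rotary_dim)
    return jnp.cos(angles), jnp.sin(angles)


def rotate_half(x):
    # Split the last axis into halves (x1, x2) and return (-x2, x1).
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate([-x2, x1], axis=-1)


def apply_rotary(x, cos, sin, rotary_dim: int):
    """Rotate the first `rotary_dim` channels of x and pass the rest through.

    x: (seq_len, num_heads, head_dim); cos/sin: (seq_len, rotary_dim).
    """
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    cos = cos[:, None, :]  # broadcast over the heads axis
    sin = sin[:, None, :]
    x_rot = x_rot * cos + rotate_half(x_rot) * sin
    return jnp.concatenate([x_rot, x_pass], axis=-1)


# Usage: rotate a random query tensor (seq_len=8, num_heads=4, head_dim=64).
q = jax.random.normal(jax.random.PRNGKey(0), (8, 4, 64))
cos, sin = rope_frequencies(seq_len=8, rotary_dim=16)
q_rot = apply_rotary(q, cos, sin, rotary_dim=16)
```

Because each channel pair is rotated by a position-dependent angle, the rotation preserves vector norms, and position 0 (angle 0) is left unchanged.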
- config
Configuration object for the model.
- Type
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
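The `dtype`, `param_dtype`, and `precision` attributes typically interact as in the following minimal JAX sketch: weights are stored in `param_dtype`, operands are cast to `dtype` before the computation, and `precision` is forwarded to the matrix multiply. This is a common Flax-style pattern shown for illustration under that assumption; EasyDeL's internal casting may differ, and the variable names are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical values standing in for the module attributes.
param_dtype = jnp.float32                 # how the weights are stored
dtype = jnp.bfloat16                      # what the forward pass computes in
precision = jax.lax.Precision.HIGHEST     # forwarded to the XLA dot

kernel = jax.random.normal(jax.random.PRNGKey(0), (16, 32), dtype=param_dtype)
x = jnp.ones((4, 16), dtype=jnp.float32)

# Cast both operands to the compute dtype, then multiply with the requested precision.
y = jnp.dot(x.astype(dtype), kernel.astype(dtype), precision=precision)
```

Keeping parameters in float32 while computing in bfloat16 is a common way to save memory bandwidth without losing the precision of the stored weights.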
- class easydel.modules.gpt_neox.modeling_gpt_neox_flax.GPTNeoXBlock(*args: Any, **kwargs: Any)
Bases: Module
GPT-NeoX Transformer block.
This module represents a single transformer block in the GPT-NeoX model, containing self-attention and MLP sub-layers with residual connections and layer normalization. It supports both standard and parallel residual connections.
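The difference between standard and parallel residual connections can be sketched in JAX. This assumes the pre-norm layout used by GPT-NeoX-style models: in the parallel form both sub-layers read the same block input and their outputs are added together, while in the standard (sequential) form the MLP reads the stream already updated by attention. The helpers `layer_norm` and `block` are hypothetical; for brevity a single parameter-free LayerNorm stands in for the two learned norms a real block would use, and simple linear maps stand in for attention and the MLP.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps: float = 1e-5):
    # Parameter-free LayerNorm over the hidden axis (scale and bias omitted).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def block(x, attn, mlp, use_parallel_residual: bool):
    """Pre-norm transformer block; `attn` and `mlp` are arbitrary callables."""
    if use_parallel_residual:
        # Both sub-layers see the same input; their outputs are summed.
        return x + attn(layer_norm(x)) + mlp(layer_norm(x))
    # Sequential residual: the MLP sees the attention-updated stream.
    h = x + attn(layer_norm(x))
    return h + mlp(layer_norm(h))


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
w_attn = jax.random.normal(k1, (8, 8)) * 0.1
w_mlp = jax.random.normal(k2, (8, 8)) * 0.1


def attn(h):
    return h @ w_attn  # stand-in for self-attention


def mlp(h):
    return jnp.tanh(h @ w_mlp)  # stand-in for the feed-forward layer


x = jax.random.normal(k3, (5, 8))  # (seq_len, hidden_size)
parallel_out = block(x, attn, mlp, use_parallel_residual=True)
sequential_out = block(x, attn, mlp, use_parallel_residual=False)
```

The parallel form lets the attention and MLP computations run independently of each other, which is the usual motivation for it.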
- config
Configuration object for the model.
- Type
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
- class easydel.modules.gpt_neox.modeling_gpt_neox_flax.GPTNeoXForCausalLM(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
GPT-NeoX model with a language modeling head.
This model wraps the base GPTNeoXModel and adds a linear layer on top that projects the final hidden states to vocabulary logits, which are used to predict the next token in a sequence. This makes it suitable for causal language modeling tasks.
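The language modeling head can be sketched in JAX as a single projection from hidden states to vocabulary logits, followed by greedy next-token selection and the shifted causal LM loss. The weight name `lm_head`, the bias-free projection, and all shapes are hypothetical; random tensors stand in for the base model's output and the token sequence.

```python
import jax
import jax.numpy as jnp

vocab_size, hidden_size, seq_len = 50, 16, 6
k_h, k_w, k_t = jax.random.split(jax.random.PRNGKey(0), 3)

hidden_states = jax.random.normal(k_h, (seq_len, hidden_size))      # base model output
lm_head = jax.random.normal(k_w, (hidden_size, vocab_size)) * 0.02  # hypothetical weight

logits = hidden_states @ lm_head        # (seq_len, vocab_size)
next_token = jnp.argmax(logits[-1])     # greedy prediction after the last position

# Causal LM loss: position t predicts token t + 1, so logits and labels are shifted by one.
tokens = jax.random.randint(k_t, (seq_len,), 0, vocab_size)
log_probs = jax.nn.log_softmax(logits[:-1], axis=-1)
targets = tokens[1:]
loss = -jnp.mean(jnp.take_along_axis(log_probs, targets[:, None], axis=-1))
```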
- config
Configuration object for the model.
- Type
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
- class easydel.modules.gpt_neox.modeling_gpt_neox_flax.GPTNeoXMlp(*args: Any, **kwargs: Any)
Bases: Module
GPT-NeoX MLP module.
This module implements the feed-forward network used in the GPT-NeoX model.
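The feed-forward network can be sketched in JAX, assuming the usual two-layer design: expand to an intermediate width, apply GELU, and project back to the hidden size. The parameter names and the choice of GELU here are assumptions for illustration; the actual layer sizes and activation come from the model configuration.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, hidden_size: int, intermediate_size: int):
    # Hypothetical parameter layout: one up-projection and one down-projection.
    k_up, k_down = jax.random.split(key)
    return {
        "up_kernel": jax.random.normal(k_up, (hidden_size, intermediate_size)) * 0.02,
        "up_bias": jnp.zeros((intermediate_size,)),
        "down_kernel": jax.random.normal(k_down, (intermediate_size, hidden_size)) * 0.02,
        "down_bias": jnp.zeros((hidden_size,)),
    }


def mlp(params, x):
    # Expand to the intermediate width, apply GELU, then project back.
    h = jax.nn.gelu(x @ params["up_kernel"] + params["up_bias"])
    return h @ params["down_kernel"] + params["down_bias"]


params = init_mlp(jax.random.PRNGKey(0), hidden_size=16, intermediate_size=64)
y = mlp(params, jnp.ones((3, 16)))  # (3, 16)
```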
- config
Configuration object for the model.
- Type
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
- class easydel.modules.gpt_neox.modeling_gpt_neox_flax.GPTNeoXModel(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
GPT-NeoX model implementation.
This class implements the main GPT-NeoX transformer model architecture, consisting of an embedding layer, multiple GPTNeoXBlock layers, and a final layer normalization.
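The data flow through the model (embedding lookup, a stack of blocks, then a final normalization) can be sketched in JAX. The `model_forward` helper, the parameter-free `layer_norm`, and the toy residual blocks are hypothetical stand-ins for the embedding layer, the GPTNeoXBlock layers, and the learned final LayerNorm.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps: float = 1e-5):
    # Parameter-free LayerNorm over the hidden axis (scale and bias omitted).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def model_forward(input_ids, embedding, blocks):
    """Embed tokens, run each block in order, then apply a final LayerNorm."""
    hidden = embedding[input_ids]  # (seq_len, hidden_size) table lookup
    for block in blocks:
        hidden = block(hidden)
    return layer_norm(hidden)


vocab_size, hidden_size, num_layers = 32, 8, 3
keys = jax.random.split(jax.random.PRNGKey(0), num_layers + 1)
embedding = jax.random.normal(keys[0], (vocab_size, hidden_size))
# Stand-in blocks: residual updates with hypothetical per-layer weights.
weights = [jax.random.normal(k, (hidden_size, hidden_size)) * 0.1 for k in keys[1:]]
blocks = [lambda h, w=w: h + jnp.tanh(h @ w) for w in weights]

input_ids = jnp.array([1, 5, 7, 2])
hidden_states = model_forward(input_ids, embedding, blocks)  # (4, 8)
```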
- config
Configuration object for the model.
- Type
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
- property frequencies
Returns the frequency components (e.g., for RoPE) derived from the configuration.
The value is computed with self.config.get_basic_frequencies() on first access and cached for subsequent accesses.
- Returns
The frequency components, potentially cached.
- Return type
jnp.ndarray
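The compute-once-then-cache behavior of `frequencies` can be illustrated with `functools.cached_property`. This is a sketch under stated assumptions: `ToyConfig`, `compute_frequencies`, and `ToyModel` are hypothetical stand-ins, the stand-in returns RoPE inverse frequencies, and neither what `get_basic_frequencies()` returns exactly nor the caching mechanism EasyDeL uses internally is confirmed here.

```python
import functools

import jax.numpy as jnp


class ToyConfig:
    """Hypothetical stand-in for a model config; not EasyDeL's config class."""

    def __init__(self, head_dim: int, rope_theta: float = 10000.0):
        self.head_dim = head_dim
        self.rope_theta = rope_theta
        self.calls = 0  # counts how often the frequencies are computed

    def compute_frequencies(self):
        self.calls += 1
        # RoPE inverse frequencies, one per pair of channels.
        return 1.0 / (self.rope_theta ** (jnp.arange(0, self.head_dim, 2) / self.head_dim))


class ToyModel:
    def __init__(self, config: ToyConfig):
        self.config = config

    @functools.cached_property
    def frequencies(self):
        # Computed on first access, then stored on the instance and reused.
        return self.config.compute_frequencies()


model = ToyModel(ToyConfig(head_dim=8))
first = model.frequencies
second = model.frequencies  # served from the cache; no recomputation
```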