easydel.modules.gpt2.modeling_gpt2_flax
- class easydel.modules.gpt2.modeling_gpt2_flax.Conv1D(*args: Any, **kwargs: Any)
Bases: Module
Custom 1D convolution layer used in GPT-2.
This layer implements the "1D convolution" that earlier GPT architectures use in place of a standard linear layer. Functionally it is a linear projection: it transposes the kernel and then multiplies the input by it, adding a bias when use_bias is set.
- in_features
Dimensionality of the input features.
- Type
int
- out_features
Dimensionality of the output features.
- Type
int
- use_bias
Whether to include a bias term. Defaults to True.
- Type
bool
- dtype
Data type for computations. Defaults to jnp.float32.
- Type
jnp.dtype
- param_dtype
Data type for parameters. Defaults to jnp.float32.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- dot_general
Custom dot_general function. Defaults to None (uses jax.lax.dot_general).
- Type
tp.Optional[callable]
- rngs
Random number generators.
- Type
nn.Rngs
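As a rough illustration of the computation described for Conv1D, the JAX sketch below multiplies the input by a transposed kernel and optionally adds a bias. The function name `conv1d_forward` and the `(out_features, in_features)` kernel layout are assumptions made for this example; they are not taken from the EasyDeL source.

```python
import jax
import jax.numpy as jnp


def conv1d_forward(x, kernel, bias=None):
    """Hypothetical stand-in for the Conv1D forward pass.

    x:      (..., in_features)
    kernel: (out_features, in_features)  # assumed layout
    bias:   (out_features,) or None
    """
    # Transpose the kernel, then apply it as a plain matrix multiplication.
    y = jnp.matmul(x, kernel.T)
    if bias is not None:
        y = y + bias
    return y


x = jax.random.normal(jax.random.PRNGKey(0), (2, 5, 8))
kernel = jax.random.normal(jax.random.PRNGKey(1), (16, 8)) * 0.02
bias = jnp.zeros((16,))
out = conv1d_forward(x, kernel, bias)
```

The last axis of the input changes from `in_features` to `out_features`; all leading axes are left untouched, which is why the layer can replace a dense projection.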
- class easydel.modules.gpt2.modeling_gpt2_flax.GPT2Attention(*args: Any, **kwargs: Any)
Bases: AttentionModule
GPT-2 Attention module.
This module implements the multi-head attention mechanism used in GPT-2. It can operate as self-attention or, when is_cross_attention is set, as cross-attention.
- config
Configuration object for the model.
- Type
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- causal
Whether the attention is causal.
- Type
bool
- is_cross_attention
Whether the attention is cross-attention.
- Type
bool
- rngs
Random number generators.
- Type
nn.Rngs
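To make the role of the `causal` flag concrete, here is a minimal sketch of scaled dot-product attention over already-split heads. The function `attention` and its argument layout are hypothetical and only illustrate the general technique; the actual module delegates to EasyDeL's own attention machinery, which is not shown in this reference.

```python
import jax
import jax.numpy as jnp


def attention(q, k, v, causal=True):
    """Illustrative multi-head attention (not the EasyDeL implementation).

    q: (batch, heads, q_len, head_dim)
    k, v: (batch, heads, k_len, head_dim)
    """
    head_dim = q.shape[-1]
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / (head_dim ** 0.5)
    if causal:
        # Each query position may only attend to itself and earlier keys.
        q_len, k_len = q.shape[-2], k.shape[-2]
        mask = jnp.tril(jnp.ones((q_len, k_len), dtype=bool))
        scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bhkd->bhqd", weights, v), weights


keys = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(keys[0], (1, 2, 4, 8))
k = jax.random.normal(keys[1], (1, 2, 4, 8))
v = jax.random.normal(keys[2], (1, 2, 4, 8))
out, weights = attention(q, k, v, causal=True)
```

For cross-attention, `k` and `v` would come from a different sequence than `q`, and the causal mask would typically be disabled (`causal=False`).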
- class easydel.modules.gpt2.modeling_gpt2_flax.GPT2Block(*args: Any, **kwargs: Any)
Bases: Module
GPT-2 Transformer block.
This module represents a single transformer block in the GPT-2 model, containing self-attention and MLP sub-layers with residual connections and layer normalization. It can optionally include cross-attention layers.
- config
Configuration object for the model.
- Type
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
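The way residual connections and layer normalization combine in a block can be sketched as follows. This assumes the pre-normalization ordering of the original GPT-2 (normalize, apply the sub-layer, add the result back to the input). The helper names are hypothetical, and learned scale/shift parameters, dropout, and cross-attention are omitted for brevity.

```python
import jax.numpy as jnp


def layer_norm(x, eps=1e-5):
    # Normalization over the feature axis, without learned scale and bias.
    mean = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def block_forward(x, attn_fn, mlp_fn):
    """Illustrative pre-LN transformer block with two residual branches."""
    x = x + attn_fn(layer_norm(x))  # attention sub-layer
    x = x + mlp_fn(layer_norm(x))   # feed-forward sub-layer
    return x


# Stand-in sub-layers, used only to exercise the residual wiring.
x = jnp.arange(12.0).reshape(1, 3, 4)
y = block_forward(x, attn_fn=lambda h: 0.5 * h, mlp_fn=lambda h: jnp.zeros_like(h))
```

Because each sub-layer only adds to the running hidden state, the block preserves the input shape and can be stacked any number of times.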
- class easydel.modules.gpt2.modeling_gpt2_flax.GPT2LMHeadModel(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
GPT-2 model with a language modeling head.
This model extends the base GPT2Model by adding a linear layer on top to predict the next token in a sequence, making it suitable for causal language modeling tasks.
- config
Configuration object for the model.
- Type
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
- config: tp.Union[EasyDeLBaseConfig, _CP]
- dtype: jnp.dtype
- loss_type: str = 'ForCausalLM'
- param_dtype: jnp.dtype
- precision: lax.PrecisionLike
- rngs: nn.Rngs
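A language modeling head of this kind can be illustrated with a short sketch: hidden states are projected to vocabulary logits, and the causal LM loss compares the logits at position t with the token at position t+1. The names `lm_logits`, `next_token_loss`, and the `(vocab, hidden)` weight layout are assumptions for this example; whether EasyDeL ties this projection to the input embeddings is not stated here.

```python
import jax
import jax.numpy as jnp


def lm_logits(hidden, lm_head):
    """hidden: (batch, seq, hidden_size); lm_head: (vocab, hidden_size), assumed layout."""
    return hidden @ lm_head.T


def next_token_loss(logits, input_ids):
    """Mean cross-entropy between position t logits and the token at t + 1."""
    shift_logits = logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]
    log_probs = jax.nn.log_softmax(shift_logits, axis=-1)
    picked = jnp.take_along_axis(log_probs, shift_labels[..., None], axis=-1)[..., 0]
    return -picked.mean()


vocab, hidden_size = 10, 6
hidden = jnp.zeros((2, 4, hidden_size))  # all-zero states give uniform logits
lm_head = jax.random.normal(jax.random.PRNGKey(0), (vocab, hidden_size))
input_ids = jnp.array([[1, 2, 3, 4], [5, 6, 7, 8]])
logits = lm_logits(hidden, lm_head)
loss = next_token_loss(logits, input_ids)
```

With uniform logits the loss equals the natural log of the vocabulary size, which is a handy sanity check for an untrained model.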
- class easydel.modules.gpt2.modeling_gpt2_flax.GPT2MLP(*args: Any, **kwargs: Any)
Bases: Module
GPT-2 MLP module.
This module implements the feed-forward network (MLP) used in the GPT-2 model. It consists of two Conv1D layers with a GELU activation in between.
- config
Configuration object for the model.
- Type
- intermediate_size
Dimensionality of the intermediate layer.
- Type
int
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
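The two-projection structure of the MLP can be sketched in a few lines of JAX. The function `mlp_forward`, the parameter names, and the `(out_features, in_features)` kernel layout are hypothetical. The tanh-approximated GELU is used here as an assumption, since the activation variant is normally selected by the model configuration.

```python
import jax
import jax.numpy as jnp


def mlp_forward(x, w_fc, b_fc, w_proj, b_proj):
    """Illustrative GPT-2 style feed-forward network.

    w_fc:   (intermediate_size, hidden_size)
    w_proj: (hidden_size, intermediate_size)
    """
    h = x @ w_fc.T + b_fc                  # expand to intermediate_size
    h = jax.nn.gelu(h, approximate=True)   # GELU between the two projections
    return h @ w_proj.T + b_proj           # project back to hidden_size


hidden_size, intermediate_size = 8, 32
keys = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(keys[0], (2, 3, hidden_size))
w_fc = jax.random.normal(keys[1], (intermediate_size, hidden_size)) * 0.02
w_proj = jax.random.normal(keys[2], (hidden_size, intermediate_size)) * 0.02
y = mlp_forward(x, w_fc, jnp.zeros(intermediate_size), w_proj, jnp.zeros(hidden_size))
```

The output has the same shape as the input, so the result can be added directly to the residual stream inside a transformer block.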
- class easydel.modules.gpt2.modeling_gpt2_flax.GPT2Model(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
GPT-2 model implementation.
This class implements the main GPT-2 transformer model architecture, consisting of embedding layers (token and position), multiple GPT2Block layers, and a final layer normalization.
- config
Configuration object for the model.
- Type
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
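The embedding stage that feeds the stack of blocks can be illustrated as the sum of a token embedding and a learned position embedding. In the sketch below, `embed`, `wte`, and `wpe` are hypothetical names chosen for the example (following common GPT-2 conventions), and positions are assumed to start at zero with no padding offset.

```python
import jax
import jax.numpy as jnp


def embed(input_ids, wte, wpe):
    """Illustrative GPT-2 input embedding.

    input_ids: (batch, seq) integer token ids
    wte: (vocab_size, hidden_size) token embedding table
    wpe: (max_positions, hidden_size) position embedding table
    """
    positions = jnp.arange(input_ids.shape[-1])
    # Token lookup plus position lookup, broadcast over the batch axis.
    return wte[input_ids] + wpe[positions]


wte = jax.random.normal(jax.random.PRNGKey(0), (50, 8))
wpe = jax.random.normal(jax.random.PRNGKey(1), (16, 8))
input_ids = jnp.array([[3, 7, 7, 1], [0, 2, 4, 6]])
h = embed(input_ids, wte, wpe)
```

The resulting hidden states would then pass through each block in turn and a final layer normalization, as described above.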