easydel.modules.gpt2.modeling_gpt2
- class easydel.modules.gpt2.modeling_gpt2.Conv1D(*args: Any, **kwargs: Any)
Bases: Module
Custom 1D convolution layer used in GPT-2.
This layer follows the "Conv1D" used in the original OpenAI GPT and GPT-2 code. Despite its name, it behaves as a linear (dense) layer: it multiplies the input by the transposed kernel and optionally adds a bias, applying the same projection independently at every sequence position.
- in_features
Dimensionality of the input features.
- Type
int
- out_features
Dimensionality of the output features.
- Type
int
- use_bias
Whether to include a bias term. Defaults to True.
- Type
bool
- dtype
Data type for computations. Defaults to jnp.float32.
- Type
jnp.dtype
- param_dtype
Data type for parameters. Defaults to jnp.float32.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- dot_general
Custom dot_general function. Defaults to None (uses jax.lax.dot_general).
- Type
tp.Optional[callable]
- rngs
Random number generators.
- Type
nn.Rngs
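A minimal NumPy sketch of the Conv1D forward pass is shown here. It is not EasyDeL's code: it assumes the kernel is stored with shape (out_features, in_features), so that the transpose yields an (in_features, out_features) matrix, and the helper name `conv1d_forward` is hypothetical. The real module performs the contraction with jax.lax.dot_general (or the supplied dot_general) using the configured dtype and precision.

```python
import numpy as np


def conv1d_forward(x, kernel, bias=None):
    """Hypothetical sketch of GPT-2's Conv1D forward pass.

    x:      (..., in_features)
    kernel: (out_features, in_features)   assumed storage layout
    bias:   (out_features,) or None
    """
    # Transpose the kernel to (in_features, out_features) and contract it
    # with the last (feature) axis of x; every position uses the same weights.
    y = x @ kernel.T
    if bias is not None:
        y = y + bias
    return y


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 8))     # (batch, seq_len, in_features)
kernel = rng.normal(size=(16, 8))  # (out_features, in_features)
y = conv1d_forward(x, kernel, np.zeros(16))
print(y.shape)  # (2, 5, 16)
```

Because the whole operation is a single matrix multiplication over the feature axis, it matches a dense layer whose weight matrix is `kernel.T`; the name refers to a width-1 convolution applied at every position.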
- class easydel.modules.gpt2.modeling_gpt2.GPT2Attention(*args: Any, **kwargs: Any)
Bases: UnifiedAttention
GPT-2 attention module.
This module implements the standard multi-head attention mechanism used in GPT-2. It supports both causal self-attention and cross-attention; the mode is controlled by the causal and is_cross_attention attributes.
- config
Configuration object for the model.
- Type
GPT2Config
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- causal
Whether the attention is causal.
- Type
bool
- is_cross_attention
Whether the attention is cross-attention.
- Type
bool
- rngs
Random number generators.
- Type
nn.Rngs
- define_network(config: GPT2Config, dtype: jnp.dtype, param_dtype: jnp.dtype, precision: jax.lax.PrecisionLike, rngs: nn.Rngs) -> None
Create GPT-2 specific projection layers.
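The attention computation for the causal self-attention case can be sketched in plain NumPy. This is an illustration under stated assumptions, not EasyDeL's UnifiedAttention code path: it assumes a single fused Q/K/V projection (as in the reference GPT-2 `c_attn` layer) with weights in (in, out) layout, one unbatched sequence, and no dropout, padding mask, or KV cache. The function name `causal_self_attention` is hypothetical.

```python
import numpy as np


def causal_self_attention(x, w_qkv, b_qkv, w_out, b_out, num_heads):
    """Hypothetical sketch of GPT-2-style causal multi-head self-attention.

    x:     (seq_len, hidden)      a single, unbatched sequence
    w_qkv: (hidden, 3 * hidden)   fused Q/K/V projection, (in, out) layout
    w_out: (hidden, hidden)       output projection, (in, out) layout
    Returns the projected output and the attention weights.
    """
    seq_len, hidden = x.shape
    head_dim = hidden // num_heads

    # One fused projection, then split into query, key and value.
    q, k, v = np.split(x @ w_qkv + b_qkv, 3, axis=-1)

    def split_heads(t):
        # (seq_len, hidden) -> (num_heads, seq_len, head_dim)
        return t.reshape(seq_len, num_heads, head_dim).transpose(1, 0, 2)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)

    # Scaled dot-product scores: (num_heads, seq_len, seq_len).
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)

    # Causal mask: query position i may only attend to key positions j <= i.
    causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    scores = np.where(causal, scores, -1e9)

    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)

    # Weighted sum of values, merge heads, apply the output projection.
    merged = (weights @ v).transpose(1, 0, 2).reshape(seq_len, hidden)
    return merged @ w_out + b_out, weights


rng = np.random.default_rng(0)
hidden, heads, seq = 8, 2, 4
x = rng.normal(size=(seq, hidden))
out, attn = causal_self_attention(
    x,
    rng.normal(size=(hidden, 3 * hidden)), np.zeros(3 * hidden),
    rng.normal(size=(hidden, hidden)), np.zeros(hidden),
    num_heads=heads,
)
print(out.shape, attn.shape)  # (4, 8) (2, 4, 4)
```

Every row of `attn` sums to one and has zeros above the diagonal, which is exactly what the causal flag enforces; cross-attention would instead take keys and values from a separate encoder sequence and drop the mask.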
- class easydel.modules.gpt2.modeling_gpt2.GPT2Block(*args: Any, **kwargs: Any)
Bases: Module
GPT-2 transformer block.
This module represents a single transformer block in the GPT-2 model, containing self-attention and MLP sub-layers with residual connections and layer normalization. It can optionally include cross-attention layers.
- config
Configuration object for the model.
- Type
GPT2Config
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
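GPT-2 applies layer normalization at the input of each sub-layer (a "pre-LN" arrangement) and adds each sub-layer's output back to the residual stream. The NumPy sketch here shows that ordering for a block without cross-attention. `attn_fn` and `mlp_fn` are hypothetical stand-ins for the attention and MLP sub-layers, and the epsilon of 1e-5 is assumed from the usual GPT-2 default rather than read from EasyDeL.

```python
import numpy as np


def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalize over the feature (last) axis, then scale and shift.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta


def gpt2_block(x, attn_fn, mlp_fn, ln_1, ln_2):
    """Hypothetical pre-LN GPT-2 block without cross-attention.

    attn_fn, mlp_fn: callables mapping (..., hidden) -> (..., hidden)
    ln_1, ln_2:      (gamma, beta) parameter pairs for the two layer norms
    """
    # Sub-layer 1: layer norm -> self-attention -> residual add.
    h = x + attn_fn(layer_norm(x, *ln_1))
    # Sub-layer 2: layer norm -> MLP -> residual add.
    return h + mlp_fn(layer_norm(h, *ln_2))


def zeros_like_layer(t):
    # Stand-in sub-layer that contributes nothing to the residual stream.
    return np.zeros_like(t)


hidden = 4
ln_params = (np.ones(hidden), np.zeros(hidden))
x = np.arange(12.0).reshape(3, hidden)
# With zero-output sub-layers, only the residual path remains.
print(np.allclose(gpt2_block(x, zeros_like_layer, zeros_like_layer, ln_params, ln_params), x))  # True
```

When a block has cross-attention enabled, an additional normalized cross-attention sub-layer with its own residual add typically sits between the two shown here.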
- class easydel.modules.gpt2.modeling_gpt2.GPT2LMHeadModel(*args: Any, **kwargs: Any)
Bases: BaseCausalLMModule[GPT2Model, GPT2Config]
GPT-2 model with a language modeling head.
This model builds on the base GPT2Model by adding a linear layer on top that maps the final hidden states to vocabulary logits for next-token prediction, making it suitable for causal language modeling tasks.
- config
Configuration object for the model.
- Type
GPT2Config
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
- loss_type: str = 'ForCausalLM'
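The loss_type value 'ForCausalLM' presumably selects a standard next-token cross-entropy, in which the logits at position t are scored against the token at position t + 1. The NumPy sketch here illustrates that shift; `causal_lm_loss` is a hypothetical helper, and it omits details such as ignored label positions (for example padding) that a production loss typically handles.

```python
import numpy as np


def causal_lm_loss(logits, labels):
    """Hypothetical next-token cross-entropy (no ignore-index handling).

    logits: (batch, seq_len, vocab_size)  outputs of the LM head
    labels: (batch, seq_len)              target token ids, usually the input ids
    """
    # The logits at position t predict the token at position t + 1.
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]

    # Numerically stable log-softmax over the vocabulary axis.
    z = shift_logits - shift_logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

    # Log-probability assigned to each target token, averaged over positions.
    target = np.take_along_axis(log_probs, shift_labels[..., None], axis=-1)
    return -target.mean()


# Uniform logits over a 5-token vocabulary give a loss of ln(5).
logits = np.zeros((2, 4, 5))
labels = np.array([[0, 1, 2, 3], [4, 3, 2, 1]])
print(round(float(causal_lm_loss(logits, labels)), 4))  # 1.6094
```

The uniform-logits case is a useful sanity check: an untrained head that spreads probability evenly should report a loss close to the natural log of the vocabulary size.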
- class easydel.modules.gpt2.modeling_gpt2.GPT2MLP(*args: Any, **kwargs: Any)
Bases: Module
GPT-2 MLP module.
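This module implements the feed-forward network (MLP) used in the GPT-2 model. It consists of two Conv1D layers with a GELU activation in between: the first projects the hidden states up to intermediate_size, and the second projects them back to the hidden size.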
- config
Configuration object for the model.
- Type
GPT2Config
- intermediate_size
Dimensionality of the intermediate layer.
- Type
int
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
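The feed-forward path can be sketched as two (in, out)-layout matrix multiplications with a GELU in between. This sketch assumes the tanh approximation of GELU (often named `gelu_new` in GPT-2 configurations); the activation EasyDeL actually applies may be taken from the model config. The function names are hypothetical, and the intermediate size of four times the hidden size follows the conventional GPT-2 ratio.

```python
import numpy as np


def gelu_tanh(x):
    # Tanh approximation of GELU (commonly called "gelu_new" for GPT-2).
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def gpt2_mlp(x, w_fc, b_fc, w_proj, b_proj):
    """Hypothetical GPT-2 feed-forward network.

    w_fc:   (hidden, intermediate)  first Conv1D weights, (in, out) layout
    w_proj: (intermediate, hidden)  second Conv1D weights, (in, out) layout
    """
    h = gelu_tanh(x @ w_fc + b_fc)  # expand to the intermediate size
    return h @ w_proj + b_proj      # project back to the hidden size


rng = np.random.default_rng(0)
hidden = 8
intermediate = 4 * hidden  # conventional GPT-2 ratio
x = rng.normal(size=(3, hidden))
out = gpt2_mlp(
    x,
    rng.normal(size=(hidden, intermediate)), np.zeros(intermediate),
    rng.normal(size=(intermediate, hidden)), np.zeros(hidden),
)
print(out.shape)  # (3, 8)
```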
- class easydel.modules.gpt2.modeling_gpt2.GPT2Model(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
GPT-2 model implementation.
This class implements the main GPT-2 transformer model architecture, consisting of embedding layers (token and position), multiple GPT2Block layers, and a final layer normalization.
- config
Configuration object for the model.
- Type
GPT2Config
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
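The top-level GPT2Model data flow (token plus position embeddings, the stack of blocks, then a final layer norm) can be sketched as follows. The table names `wte` and `wpe` follow the original GPT-2 checkpoint naming and are used here only for illustration; the sketch assumes positions 0 to seq_len - 1, treats each block as an opaque callable, and omits dropout, attention masks, and caching. `gpt2_model_forward` is a hypothetical name.

```python
import numpy as np


def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalize over the feature (last) axis, then scale and shift.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta


def gpt2_model_forward(input_ids, wte, wpe, blocks, ln_f):
    """Hypothetical sketch of the GPT-2 backbone.

    input_ids: (batch, seq_len) integer token ids
    wte:       (vocab_size, hidden)     token embedding table
    wpe:       (max_positions, hidden)  learned position embedding table
    blocks:    callables mapping (batch, seq_len, hidden) to the same shape
    ln_f:      (gamma, beta) for the final layer norm
    """
    positions = np.arange(input_ids.shape[1])  # assumed 0 .. seq_len - 1
    # Token embeddings plus learned absolute position embeddings.
    h = wte[input_ids] + wpe[positions]
    # Apply the transformer blocks in order.
    for block in blocks:
        h = block(h)
    # Final layer normalization.
    return layer_norm(h, *ln_f)


rng = np.random.default_rng(0)
vocab_size, max_positions, hidden = 10, 8, 4
wte = rng.normal(size=(vocab_size, hidden))
wpe = rng.normal(size=(max_positions, hidden))
input_ids = np.array([[1, 2, 3], [4, 5, 6]])
hidden_states = gpt2_model_forward(
    input_ids, wte, wpe, blocks=[], ln_f=(np.ones(hidden), np.zeros(hidden))
)
print(hidden_states.shape)  # (2, 3, 4)
```

Because the position table has a fixed number of rows, this layout also implies a maximum context length equal to `max_positions`.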