easydel.modules.gpt2.modeling_gpt2_flax

class easydel.modules.gpt2.modeling_gpt2_flax.Conv1D(*args: Any, **kwargs: Any)

Bases: Module

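Custom 1D convolution (Conv1D) layer used in GPT-2.

Despite its name, this layer does not slide a kernel along the sequence. Like the Conv1D module of the original GPT-2 code, it is a dense projection over the last input dimension and fills the role that a linear layer plays in most other transformer implementations. It transposes the kernel, multiplies the input by it, and adds a bias when use_bias is enabled.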

in_features

Dimensionality of the input features.

Type

int

out_features

Dimensionality of the output features.

Type

int

use_bias

Whether to include a bias term. Defaults to True.

Type

bool

dtype

Data type for computations. Defaults to jnp.float32.

Type

jnp.dtype

param_dtype

Data type for parameters. Defaults to jnp.float32.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

dot_general

Custom dot_general function. Defaults to None (uses jax.lax.dot_general).

Type

tp.Optional[callable]

rngs

Random number generators.

Type

nn.Rngs
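The transpose-then-multiply behaviour described for Conv1D can be sketched in plain JAX. This is a minimal sketch, not EasyDeL's implementation: the function name `conv1d_forward` is hypothetical, and it assumes the kernel is stored with shape `(out_features, in_features)` (the layout used by the Hugging Face Flax port of GPT-2), which is why it is transposed before `jax.lax.dot_general` contracts it with the last axis of the input.

```python
import jax
import jax.numpy as jnp


def conv1d_forward(x, kernel, bias=None, precision=None):
    """GPT-2 style "Conv1D": a dense projection of the last axis of `x`.

    Assumption (hypothetical layout): `kernel` has shape
    (out_features, in_features), so it is transposed before the matmul.
    """
    w = kernel.T  # (in_features, out_features)
    # Contract the last axis of x with axis 0 of w; no batch dimensions.
    y = jax.lax.dot_general(
        x,
        w,
        dimension_numbers=(((x.ndim - 1,), (0,)), ((), ())),
        precision=precision,
    )
    if bias is not None:  # plays the role of use_bias=True
        y = y + bias
    return y


key_x, key_k = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (2, 5, 8))     # (batch, seq_len, in_features)
kernel = jax.random.normal(key_k, (16, 8))  # (out_features, in_features)
bias = jnp.zeros((16,))
y = conv1d_forward(x, kernel, bias)
print(y.shape)  # (2, 5, 16)
```

The result is identical to `x @ kernel.T + bias`; using `dot_general` directly is what makes a custom `dot_general` or a `precision` setting easy to plug in.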

class easydel.modules.gpt2.modeling_gpt2_flax.GPT2Attention(*args: Any, **kwargs: Any)

Bases: AttentionModule

GPT-2 Attention module.

This module implements the multi-head attention mechanism used in GPT-2. It can operate as (causal) self-attention or, when is_cross_attention is set, as cross-attention.

config

Configuration object for the model.

Type

GPT2Config

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

causal

Whether to apply a causal mask, so that each position attends only to itself and earlier positions.

Type

bool

is_cross_attention

Whether the module performs cross-attention (keys and values taken from encoder hidden states) instead of self-attention.

Type

bool

rngs

Random number generators.

Type

nn.Rngs
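The causal multi-head self-attention computed by a module like GPT2Attention can be sketched functionally in JAX. This is an illustrative sketch, not EasyDeL's code: the function and weight names are hypothetical, it assumes a single fused query/key/value projection (as in the original GPT-2 `c_attn` layer), and it omits biases, dropout, KV caching and the cross-attention path.

```python
import jax
import jax.numpy as jnp


def causal_self_attention(x, w_qkv, w_out, num_heads):
    """Causal multi-head self-attention (hypothetical weight layout).

    x: (B, T, D); w_qkv: (D, 3 * D) fused Q/K/V projection; w_out: (D, D).
    """
    batch, seq_len, d_model = x.shape
    head_dim = d_model // num_heads

    q, k, v = jnp.split(x @ w_qkv, 3, axis=-1)  # each (B, T, D)

    def split_heads(t):
        # (B, T, D) -> (B, H, T, head_dim)
        return t.reshape(batch, seq_len, num_heads, head_dim).transpose(0, 2, 1, 3)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)

    scores = q @ k.transpose(0, 1, 3, 2) / jnp.sqrt(head_dim)  # (B, H, T, T)
    # Causal mask: query position i may only see key positions j <= i.
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)

    out = weights @ v  # (B, H, T, head_dim)
    out = out.transpose(0, 2, 1, 3).reshape(batch, seq_len, d_model)
    return out @ w_out


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (1, 4, 8))          # (batch, seq_len, d_model)
w_qkv = 0.1 * jax.random.normal(k2, (8, 24))  # d_model -> 3 * d_model
w_out = 0.1 * jax.random.normal(k3, (8, 8))
y = causal_self_attention(x, w_qkv, w_out, num_heads=2)

# Perturb only the last token: outputs at earlier positions must not change.
x2 = x.at[:, -1].set(0.0)
y2 = causal_self_attention(x2, w_qkv, w_out, num_heads=2)
print(y.shape)  # (1, 4, 8)
print(bool(jnp.allclose(y[:, :-1], y2[:, :-1], atol=1e-5)))  # True
```

The final check is a quick way to confirm the `causal` behaviour: no position can be influenced by tokens that come after it.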

class easydel.modules.gpt2.modeling_gpt2_flax.GPT2Block(*args: Any, **kwargs: Any)

Bases: Module

GPT-2 Transformer block.

This module represents a single transformer block in the GPT-2 model, containing self-attention and MLP sub-layers with residual connections and layer normalization. It can optionally include cross-attention layers.

config

Configuration object for the model.

Type

GPT2Config

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs
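The original GPT-2 places layer normalization at the input of each sub-layer, with a residual connection around it ("pre-LN"). The sketch below assumes that arrangement for GPT2Block and takes the attention and MLP sub-layers as arbitrary callables; `layer_norm`, `gpt2_block` and the `(scale, bias)` tuples are hypothetical names, and the optional cross-attention sub-layer is left out.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, scale, bias, eps=1e-5):
    """LayerNorm over the last (hidden) axis."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps) * scale + bias


def gpt2_block(x, attn_fn, mlp_fn, ln_1, ln_2):
    """Pre-LN transformer block (cross-attention omitted).

    attn_fn, mlp_fn: callables mapping (B, T, D) -> (B, T, D).
    ln_1, ln_2: (scale, bias) pairs for the two LayerNorms.
    """
    h = x + attn_fn(layer_norm(x, *ln_1))  # attention sub-layer + residual
    h = h + mlp_fn(layer_norm(h, *ln_2))   # MLP sub-layer + residual
    return h


def zero_sublayer(t):
    # Stand-in sub-layer that contributes nothing to the residual stream.
    return jnp.zeros_like(t)


x = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 8))
ln = (jnp.ones(8), jnp.zeros(8))
y = gpt2_block(x, zero_sublayer, zero_sublayer, ln, ln)
print(bool(jnp.allclose(y, x)))  # True: only the residual path remains
```

Because each sub-layer only adds to the residual stream, the block reduces to the identity when both sub-layers output zeros.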

class easydel.modules.gpt2.modeling_gpt2_flax.GPT2LMHeadModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

GPT-2 model with a language modeling head.

This model extends the base GPT2Model by adding a linear layer on top to predict the next token in a sequence, making it suitable for causal language modeling tasks.

config

Configuration object for the model.

Type

GPT2Config

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

config: tp.Union[EasyDeLBaseConfig, _CP]
dtype: jnp.dtype
loss_type: str = 'ForCausalLM'
param_dtype: jnp.dtype
precision: lax.PrecisionLike
rngs: nn.Rngs
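The language modeling head and the next-token objective implied by `loss_type = 'ForCausalLM'` can be sketched as follows. This is an assumption-laden sketch, not EasyDeL's loss code: the function names are hypothetical, the head is assumed to be tied to the token embedding matrix (as in the original GPT-2), and the loss is the standard shifted cross-entropy where logits at position t score token t + 1.

```python
import jax
import jax.numpy as jnp


def lm_head_logits(hidden, wte):
    """Map final hidden states (B, T, D) to vocabulary logits (B, T, V).

    Assumption: the head reuses (is tied to) the token embedding matrix
    `wte` of shape (V, D), as in the original GPT-2.
    """
    return hidden @ wte.T


def causal_lm_loss(logits, input_ids):
    """Mean next-token cross-entropy: logits at position t score token t + 1."""
    shift_logits = logits[:, :-1]    # no target exists for the last position
    shift_labels = input_ids[:, 1:]  # the first token is never predicted
    log_probs = jax.nn.log_softmax(shift_logits, axis=-1)
    token_log_probs = jnp.take_along_axis(
        log_probs, shift_labels[..., None], axis=-1
    )[..., 0]
    return -token_log_probs.mean()


vocab_size, d_model = 10, 4
hidden = jax.random.normal(jax.random.PRNGKey(0), (2, 5, d_model))
wte = jax.random.normal(jax.random.PRNGKey(1), (vocab_size, d_model))
logits = lm_head_logits(hidden, wte)
print(logits.shape)  # (2, 5, 10)

# With all-equal logits every token has probability 1 / vocab_size.
input_ids = jnp.array([[1, 2, 3, 4, 5], [0, 9, 8, 7, 6]])
loss = causal_lm_loss(jnp.zeros_like(logits), input_ids)
print(bool(jnp.allclose(loss, jnp.log(vocab_size))))  # True
```

Uniform logits give a loss of log(vocab_size), a handy sanity check for a freshly initialized model.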

class easydel.modules.gpt2.modeling_gpt2_flax.GPT2MLP(*args: Any, **kwargs: Any)

Bases: Module

GPT-2 MLP module.

This module implements the feed-forward network (MLP) used in the GPT-2 model. It consists of two Conv1D layers with a GELU activation in between.

config

Configuration object for the model.

Type

GPT2Config

intermediate_size

Dimensionality of the intermediate layer.

Type

int

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs
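The expand, activate, project pattern of GPT2MLP can be sketched with plain matrices standing in for the two Conv1D layers. The names and the `(in, out)` weight layout are hypothetical, dropout is omitted, and the choice of `jax.nn.gelu(..., approximate=True)` is an assumption: it is the tanh approximation that the original GPT-2 configuration uses (`gelu_new`), but the exact activation EasyDeL selects comes from the config.

```python
import jax
import jax.numpy as jnp


def gpt2_mlp(x, w_fc, b_fc, w_proj, b_proj):
    """Feed-forward sub-layer: expand, apply GELU, project back.

    Hypothetical layout using plain (in, out) matrices:
      w_fc: (d_model, intermediate_size), w_proj: (intermediate_size, d_model).
    """
    h = x @ w_fc + b_fc                   # (B, T, intermediate_size)
    h = jax.nn.gelu(h, approximate=True)  # tanh approximation of GELU
    return h @ w_proj + b_proj            # (B, T, d_model)


d_model = 8
intermediate_size = 4 * d_model  # the original GPT-2 default expansion
keys = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(keys[0], (2, 3, d_model))
w_fc = 0.02 * jax.random.normal(keys[1], (d_model, intermediate_size))
w_proj = 0.02 * jax.random.normal(keys[2], (intermediate_size, d_model))
y = gpt2_mlp(x, w_fc, jnp.zeros(intermediate_size), w_proj, jnp.zeros(d_model))
print(y.shape)  # (2, 3, 8)
```

The MLP acts on each position independently; only the hidden dimension is expanded and then projected back, so the sequence shape is preserved.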

class easydel.modules.gpt2.modeling_gpt2_flax.GPT2Model(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

GPT-2 model implementation.

This class implements the main GPT-2 transformer model architecture, consisting of embedding layers (token and position), multiple GPT2Block layers, and a final layer normalization.

config

Configuration object for the model.

Type

GPT2Config

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs
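The GPT2Model data flow (token embeddings plus learned position embeddings, a stack of blocks, then a final layer normalization) can be sketched as a single function. This is a hypothetical sketch rather than EasyDeL's module: `gpt2_model` and its argument names are illustrative, blocks are arbitrary callables, and dropout, attention masks and caching are omitted.

```python
import jax
import jax.numpy as jnp


def gpt2_model(input_ids, wte, wpe, blocks, ln_f, eps=1e-5):
    """Sketch of the GPT-2 trunk.

    wte: (vocab_size, d_model) token embeddings.
    wpe: (max_positions, d_model) learned position embeddings.
    blocks: callables mapping (B, T, D) -> (B, T, D), applied in order.
    ln_f: (scale, bias) of the final LayerNorm.
    """
    positions = jnp.arange(input_ids.shape[1])
    # Token + position embeddings; wpe[positions] broadcasts over the batch.
    h = wte[input_ids] + wpe[positions]
    for block in blocks:
        h = block(h)
    # Final LayerNorm over the hidden dimension.
    mean = h.mean(axis=-1, keepdims=True)
    var = h.var(axis=-1, keepdims=True)
    scale, bias = ln_f
    return (h - mean) / jnp.sqrt(var + eps) * scale + bias


def identity_block(h):
    return h  # stand-in for a GPT2Block


vocab_size, max_positions, d_model = 11, 16, 8
k_tok, k_pos = jax.random.split(jax.random.PRNGKey(0))
wte = jax.random.normal(k_tok, (vocab_size, d_model))
wpe = jax.random.normal(k_pos, (max_positions, d_model))
input_ids = jnp.array([[1, 2, 3], [4, 5, 6]])

out = gpt2_model(input_ids, wte, wpe, [identity_block] * 2,
                 (jnp.ones(d_model), jnp.zeros(d_model)))
print(out.shape)  # (2, 3, 8)
```

Because position information enters only through `wpe`, sequences longer than `max_positions` cannot be embedded without extending that table.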