easydel.modules.gpt_j.modeling_gpt_j_flax

class easydel.modules.gpt_j.modeling_gpt_j_flax.GPTJAttention(*args: Any, **kwargs: Any)

Bases: AttentionModule

GPT-J Attention module.

This module implements the attention mechanism used in the GPT-J model, including rotary position embeddings.

config

Configuration object for the model.

Type

GPTJConfig

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

causal

Whether the attention is causal.

Type

bool

is_cross_attention

Whether the attention is cross-attention.

Type

bool

rngs

Random number generators.

Type

nn.Rngs
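
The rotary position embeddings mentioned in the GPTJAttention description can be illustrated with a small standalone sketch. This is not EasyDeL's implementation: the function names (`rotary_frequencies`, `rotate_every_two`, `apply_rotary`), the base of 10000, and the `(batch, seq, heads, head_dim)` layout are assumptions chosen for illustration. It follows the interleaved-pair rotation commonly associated with GPT-J, applied only to the first `rotary_dim` features of each head.

```python
import jax.numpy as jnp


def rotary_frequencies(seq_len: int, rotary_dim: int, base: float = 10000.0):
    """Return (sin, cos) tables, each of shape (seq_len, rotary_dim // 2)."""
    inv_freq = 1.0 / (base ** (jnp.arange(0, rotary_dim, 2) / rotary_dim))
    angles = jnp.arange(seq_len)[:, None] * inv_freq[None, :]
    return jnp.sin(angles), jnp.cos(angles)


def rotate_every_two(x):
    """Map each adjacent pair (x0, x1) to (-x1, x0) along the last axis."""
    x_even = x[..., ::2]
    x_odd = x[..., 1::2]
    return jnp.stack((-x_odd, x_even), axis=-1).reshape(x.shape)


def apply_rotary(x, sin, cos, rotary_dim: int):
    """Rotate the first `rotary_dim` features of x and pass the rest through.

    x has the assumed layout (batch, seq, heads, head_dim).
    """
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    # Repeat each angle twice so sin/cos line up with the interleaved pairs,
    # then broadcast over the batch and head axes.
    sin = jnp.repeat(sin, 2, axis=-1)[None, :, None, :]
    cos = jnp.repeat(cos, 2, axis=-1)[None, :, None, :]
    x_rot = x_rot * cos + rotate_every_two(x_rot) * sin
    return jnp.concatenate((x_rot, x_pass), axis=-1)
```

In a sketch like this the same rotation would be applied to both queries and keys before the attention scores are computed, so that their dot product depends on relative position.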

class easydel.modules.gpt_j.modeling_gpt_j_flax.GPTJBlock(*args: Any, **kwargs: Any)

Bases: Module

GPT-J Transformer block.

This module represents a single transformer block in the GPT-J model, containing self-attention and MLP sub-layers with residual connections and layer normalization.

config

Configuration object for the model.

Type

GPTJConfig

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs
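
The way a GPT-J style block combines its sub-layers can be sketched as follows. In the original GPT-J architecture the attention and MLP branches both read the same layer-normalized input and are added to the residual in a single step; whether GPTJBlock mirrors this exactly should be checked against the source. The helper names (`layer_norm`, `gptj_style_block`) and the callables `attn_fn` and `mlp_fn` are hypothetical placeholders.

```python
import jax.numpy as jnp


def layer_norm(x, scale, bias, eps: float = 1e-5):
    """Normalize over the last axis, then apply a learned scale and bias."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps) * scale + bias


def gptj_style_block(x, ln_scale, ln_bias, attn_fn, mlp_fn):
    """Parallel residual: x + attention(LN(x)) + MLP(LN(x)).

    attn_fn and mlp_fn are stand-ins for the attention and feed-forward
    sub-layers; each maps (batch, seq, hidden) to the same shape.
    """
    normed = layer_norm(x, ln_scale, ln_bias)
    return x + attn_fn(normed) + mlp_fn(normed)
```

Because both branches share one normalized input, only a single layer normalization is needed per block in this sketch.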

class easydel.modules.gpt_j.modeling_gpt_j_flax.GPTJForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

GPT-J model with a language modeling head.

This model extends the base GPTJModel by adding a linear layer on top to predict the next token in a sequence, making it suitable for causal language modeling tasks.

config

Configuration object for the model.

Type

GPTJConfig

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs
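
What the language modeling head adds on top of the base model can be shown with a minimal sketch: a linear projection from hidden states to vocabulary logits, followed here by a greedy pick of the next token. The function names, the explicit `kernel` and `bias` arrays, and the greedy decoding step are illustrative assumptions and do not reflect the GPTJForCausalLM call signature.

```python
import jax.numpy as jnp


def lm_head_logits(hidden_states, kernel, bias):
    """Project (batch, seq, hidden) states to (batch, seq, vocab) logits."""
    return hidden_states @ kernel + bias


def greedy_next_token(logits):
    """Pick the highest-scoring token id at the last position of each sequence."""
    return jnp.argmax(logits[:, -1, :], axis=-1)


# Toy usage with hidden size 3 and vocabulary size 4 (values are arbitrary).
hidden = jnp.ones((1, 2, 3))
kernel = jnp.zeros((3, 4)).at[:, 2].set(1.0)  # only token 2 receives a score
bias = jnp.zeros(4)
next_ids = greedy_next_token(lm_head_logits(hidden, kernel, bias))
```

For training, the same logits would typically be compared against the input sequence shifted by one position using a cross-entropy loss.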

class easydel.modules.gpt_j.modeling_gpt_j_flax.GPTJMLP(*args: Any, **kwargs: Any)

Bases: Module

GPT-J MLP module.

This module implements the feed-forward network used in the GPT-J model.

config

Configuration object for the model.

Type

GPTJConfig

intermediate_size

Dimensionality of the intermediate layer.

Type

int

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs
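
As a rough picture of the feed-forward network, the sketch below expands the hidden states to `intermediate_size`, applies a GELU activation, and projects back. The function name and explicit weight arguments are hypothetical, and the tanh-approximated GELU is an assumption based on the activation commonly used with GPT-J; the actual activation is determined by the model configuration.

```python
import jax
import jax.numpy as jnp


def gptj_style_mlp(x, w_in, b_in, w_out, b_out):
    """Two-layer feed-forward network: hidden -> intermediate -> hidden."""
    h = x @ w_in + b_in                   # (batch, seq, intermediate_size)
    h = jax.nn.gelu(h, approximate=True)  # tanh approximation of GELU
    return h @ w_out + b_out              # (batch, seq, hidden)


# Toy usage: hidden size 4, intermediate size 16 (sizes chosen arbitrarily).
key_in, key_out = jax.random.split(jax.random.PRNGKey(0))
w_in = jax.random.normal(key_in, (4, 16)) * 0.02
w_out = jax.random.normal(key_out, (16, 4)) * 0.02
x = jnp.ones((2, 5, 4))
y = gptj_style_mlp(x, w_in, jnp.zeros(16), w_out, jnp.zeros(4))
```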

class easydel.modules.gpt_j.modeling_gpt_j_flax.GPTJModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

GPT-J model implementation.

This class implements the main GPT-J transformer model architecture, consisting of an embedding layer, multiple GPTJBlock layers, and a final layer normalization.

config

Configuration object for the model.

Type

GPTJConfig

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

property frequencies

Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_frequencies() and caches the result.

Returns

The frequency components, potentially cached.

Return type

jnp.ndarray
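
The retrieve-or-compute-and-cache behaviour described for `frequencies` can be sketched with `functools.cached_property`. Everything here is a toy stand-in: `ToyConfig` and `ToyModel` are hypothetical classes, and the body of the toy `get_basic_frequencies` (a table of position times inverse-frequency angles) is an assumption about what such a method might return, not the behaviour of GPTJConfig.

```python
from functools import cached_property

import jax.numpy as jnp


class ToyConfig:
    """Hypothetical stand-in for a model config; not GPTJConfig."""

    def __init__(self, rotary_dim=4, max_position_embeddings=8, base=10000.0):
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.calls = 0  # counts how often the table is actually computed

    def get_basic_frequencies(self):
        # Assumed behaviour: a (positions, rotary_dim // 2) table of angles.
        self.calls += 1
        exponents = jnp.arange(0, self.rotary_dim, 2) / self.rotary_dim
        inv_freq = 1.0 / (self.base ** exponents)
        positions = jnp.arange(self.max_position_embeddings)
        return jnp.outer(positions, inv_freq)


class ToyModel:
    """Hypothetical model that exposes the table as a cached property."""

    def __init__(self, config):
        self.config = config

    @cached_property
    def frequencies(self):
        # Computed on first access, then stored on the instance.
        return self.config.get_basic_frequencies()


model = ToyModel(ToyConfig())
first = model.frequencies
second = model.frequencies  # served from the cache, no recomputation
```

Caching of this kind avoids rebuilding the table on every forward pass, at the cost of holding one array per model instance.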