easydel.modules.gpt_j.modeling_gpt_j

class easydel.modules.gpt_j.modeling_gpt_j.GPTJAttention(*args: Any, **kwargs: Any)[source]#

Bases: UnifiedAttention

GPT-J attention with partial rotary position embeddings (RoPE).

Inherits from UnifiedAttention. Uses separate query, key, and value (Q/K/V) projections, with rotary embeddings applied to only a portion of each attention head's dimensions (partial RoPE).

define_network(config: GPTJConfig, dtype: dtype, param_dtype: dtype, precision: Union[None, str, Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], DotAlgorithm, DotAlgorithmPreset], rngs: Rngs)[source]#

Define the GPT-J-specific network, including residual dropout.

projection_mapping: ClassVar[dict[str, str]] = {'key_projection': 'k_proj', 'mla_kv_a_layernorm': 'kv_a_layernorm', 'mla_kv_a_proj_with_mqa': 'kv_a_proj_with_mqa', 'mla_kv_b_proj': 'kv_b_proj', 'mla_q_a_layernorm': 'q_a_layernorm', 'mla_q_a_proj': 'q_a_proj', 'mla_q_b_proj': 'q_b_proj', 'mla_q_proj': 'q_proj', 'output_projection': 'out_proj', 'qkv_projection': 'qkv_proj', 'query_projection': 'q_proj', 'value_projection': 'v_proj'}#
class easydel.modules.gpt_j.modeling_gpt_j.GPTJBlock(*args: Any, **kwargs: Any)[source]#

Bases: Module

GPT-J Transformer block.

This module represents a single transformer block in the GPT-J model, containing self-attention and MLP sub-layers with residual connections and layer normalization.

config#

Configuration object for the model.

Type

GPTJConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.modules.gpt_j.modeling_gpt_j.GPTJForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: BaseCausalLMModule[GPTJModel, GPTJConfig]

GPT-J model with a language modeling head.

class easydel.modules.gpt_j.modeling_gpt_j.GPTJMLP(*args: Any, **kwargs: Any)[source]#

Bases: Module

GPT-J MLP module.

This module implements the feed-forward network used in the GPT-J model.

config#

Configuration object for the model.

Type

GPTJConfig

intermediate_size#

Dimensionality of the intermediate layer.

Type

int

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.modules.gpt_j.modeling_gpt_j.GPTJModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

GPT-J model implementation.

This class implements the main GPT-J transformer model architecture, consisting of an embedding layer, multiple GPTJBlock layers, and a final layer normalization.

config#

Configuration object for the model.

Type

GPTJConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

property frequencies#

Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_frequencies() and caches the result.

Returns

The frequency components, potentially cached.

Return type

jnp.ndarray

get_decoder()[source]#

Returns the decoder part of the model’s graph definition.

get_embedding()[source]#

Returns the embedding layer of the module.

get_encoder()[source]#

Returns the encoder part of the model’s graph definition. Decoder-only models do not have an encoder.

get_lm_head()[source]#

Returns the language model head of the module. Base models do not have a language model head.