easydel.modules.opt.modeling_opt_flax

Flax OPT model.

class easydel.modules.opt.modeling_opt_flax.OPTAttention(*args: Any, **kwargs: Any)

Bases: AttentionModule

OPT Attention mechanism module.

This module implements the multi-head self-attention mechanism used in the OPT model.

config

Configuration object for the model.

Type

OPTConfig

embed_dim

The dimensionality of the embedding layer.

Type

int

num_heads

The number of attention heads.

Type

int

dropout

Dropout probability for the attention scores.

Type

float

causal

Whether to use causal masking.

Type

bool

bias

Whether to include bias in the linear projections.

Type

bool

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

head_dim

Dimensionality of each attention head.

Type

int

q_proj

Linear layer for query projection.

Type

ParallelLinear

k_proj

Linear layer for key projection.

Type

ParallelLinear

v_proj

Linear layer for value projection.

Type

ParallelLinear

out_proj

Linear layer for the output projection.

Type

ParallelLinear

dropout_layer

Dropout layer applied after attention.

Type

nn.Dropout

attention_module

The core attention computation module.

Type

AttentionModule
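For orientation, the computation that OPTAttention's attributes describe (query/key/value projections, head splitting with head_dim = embed_dim / num_heads, causal masking, and an output projection) can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not EasyDeL's implementation: plain weight matrices stand in for the ParallelLinear layers, and biases, dropout, the KV cache, and the pluggable attention_module backends are all omitted. The function name causal_self_attention is hypothetical.

```python
import jax
import jax.numpy as jnp


def causal_self_attention(x, wq, wk, wv, wo, num_heads):
    """Multi-head causal self-attention for x of shape (batch, seq, embed_dim)."""
    batch, seq, embed_dim = x.shape
    head_dim = embed_dim // num_heads

    def split_heads(t):
        # (batch, seq, embed_dim) -> (batch, num_heads, seq, head_dim)
        return t.reshape(batch, seq, num_heads, head_dim).transpose(0, 2, 1, 3)

    # Project, split into heads, and scale queries by 1/sqrt(head_dim)
    q = split_heads(x @ wq) * head_dim**-0.5
    k = split_heads(x @ wk)
    v = split_heads(x @ wv)

    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k)
    # Causal mask: query position i may attend only to key positions <= i
    causal = jnp.tril(jnp.ones((seq, seq), dtype=bool))
    scores = jnp.where(causal, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)

    out = jnp.einsum("bhqk,bhkd->bhqd", weights, v)
    # Merge heads: (batch, num_heads, seq, head_dim) -> (batch, seq, embed_dim)
    out = out.transpose(0, 2, 1, 3).reshape(batch, seq, embed_dim)
    return out @ wo, weights


# Usage with random weights (hypothetical sizes)
kx, kq, kk, kv, ko = jax.random.split(jax.random.PRNGKey(0), 5)
embed_dim, num_heads = 16, 4
x = jax.random.normal(kx, (2, 5, embed_dim))
wq, wk, wv, wo = (jax.random.normal(k, (embed_dim, embed_dim)) * 0.1 for k in (kq, kk, kv, ko))
out, weights = causal_self_attention(x, wq, wk, wv, wo, num_heads)
print(out.shape)  # (2, 5, 16)
```

Masked scores are set to the dtype's minimum rather than -inf, a common choice that avoids NaNs if a row is ever fully masked (for example, by padding).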

class easydel.modules.opt.modeling_opt_flax.OPTDecoder(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

OPT Decoder stack.

This module comprises the main transformer decoder layers for the OPT model, including token embeddings, positional embeddings, the decoder layers themselves, and optional final layer normalization.

config

Configuration object for the model.

Type

OPTConfig

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

padding_idx

Index of the padding token.

Type

int

max_target_positions

Maximum sequence length the model can handle.

Type

int

embed_scale

Scaling factor for embeddings (usually 1.0).

Type

float

embed_tokens

Token embedding layer.

Type

nn.Embed

embed_positions

Positional embedding layer.

Type

OPTLearnedPositionalEmbedding

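project_out

Optional linear projection applied after the decoder layers (and after the final layer normalization, if present), mapping hidden states back to the word-embedding dimension when that dimension differs from the hidden size.

Type

ParallelLinear, optional

project_in

Optional linear projection applied to the token embeddings before the decoder layers, mapping them to the hidden size when the word-embedding dimension differs from it.

Type

ParallelLinear, optional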

layers

List of OPT decoder layers.

Type

tp.List[OPTDecoderLayer]

dropout_layer

Dropout layer applied after embeddings.

Type

nn.Dropout

final_layer_norm

Optional final layer normalization.

Type

nn.LayerNorm, optional

gradient_checkpointing

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers
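The order in which OPTDecoder applies its components can be sketched as a plain function. This is a hypothetical illustration that assumes Hugging Face-style OPT behavior: project_in runs only when the word-embedding width differs from the hidden size, positions are a simple 0..seq_len-1 range shifted by an assumed offset of 2 (no padding handling), the decoder layers are stand-in callables, dropout is omitted, and LayerNorm has no learned scale or bias. The params dictionary keys and the function name decoder_forward are illustrative, not EasyDeL names.

```python
import jax
import jax.numpy as jnp


def layer_norm(h, eps=1e-5):
    # Parameter-free LayerNorm over the last axis (no learned scale or bias)
    mean = h.mean(axis=-1, keepdims=True)
    return (h - mean) / jnp.sqrt(h.var(axis=-1, keepdims=True) + eps)


def decoder_forward(input_ids, params, layers, offset=2, final_norm=True):
    """OPT-style decoder data flow; the params keys are illustrative only."""
    batch, seq = input_ids.shape
    h = params["embed_tokens"][input_ids]  # (batch, seq, word_embed_proj_dim)
    if params.get("project_in") is not None:
        h = h @ params["project_in"]  # -> (batch, seq, hidden_size)
    # Learned positions 0..seq-1, shifted by the offset before lookup
    positions = jnp.broadcast_to(jnp.arange(seq), (batch, seq))
    h = h + params["embed_positions"][positions + offset]
    for layer in layers:  # stand-ins for OPTDecoderLayer; dropout omitted
        h = layer(h)
    if final_norm:
        h = layer_norm(h)
    if params.get("project_out") is not None:
        h = h @ params["project_out"]  # -> (batch, seq, word_embed_proj_dim)
    return h


# Usage: word_embed_proj_dim=8 differs from hidden_size=16, so both projections exist
k_tok, k_in, k_pos, k_out = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "embed_tokens": jax.random.normal(k_tok, (50, 8)),  # vocab_size=50
    "project_in": jax.random.normal(k_in, (8, 16)),
    "embed_positions": jax.random.normal(k_pos, (32 + 2, 16)),  # max positions + offset
    "project_out": jax.random.normal(k_out, (16, 8)),
}
ids = jnp.array([[1, 2, 3], [4, 5, 6]])
out = decoder_forward(ids, params, layers=[lambda h: h + 1.0])
print(out.shape)  # (2, 3, 8)
```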

class easydel.modules.opt.modeling_opt_flax.OPTDecoderLayer(*args: Any, **kwargs: Any)

Bases: Module

OPT Decoder Layer.

This module represents a single layer in the OPT decoder stack. It consists of a self-attention block and a feed-forward network (FFN), each wrapped in a residual connection, with layer normalization applied either before or after each block depending on do_layer_norm_before.

config

Configuration object for the model.

Type

OPTConfig

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

embed_dim

Dimensionality of the embedding layer.

Type

int

self_attn

The self-attention module.

Type

OPTAttention

do_layer_norm_before

Whether to apply layer normalization before the attention and FFN blocks (pre-LN). If False, layer normalization is applied after each residual addition (post-LN).

Type

bool

dropout_layer

Dropout layer applied to the hidden states.

Type

nn.Dropout

activation_fn

The activation function used in the FFN.

Type

callable

self_attn_layer_norm

Layer normalization for the self-attention block: applied to its input when do_layer_norm_before is True, and after the residual addition otherwise.

Type

nn.LayerNorm

fc1

The first linear layer of the FFN.

Type

ParallelLinear

fc2

The second linear layer (output) of the FFN.

Type

ParallelLinear

final_layer_norm

Layer normalization for the FFN block: applied to its input when do_layer_norm_before is True, and after the residual addition otherwise.

Type

nn.LayerNorm
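The effect of do_layer_norm_before on a single layer can be sketched as follows. This is a minimal, hypothetical sketch assuming the usual pre-LN/post-LN arrangement: attn and ffn are stand-in callables for self_attn and the fc1/activation/fc2 path, ReLU is used only as an example activation, LayerNorm has no learned parameters, and dropout is omitted.

```python
import jax
import jax.numpy as jnp


def layer_norm(h, eps=1e-5):
    # Parameter-free LayerNorm over the last axis (no learned scale or bias)
    mean = h.mean(axis=-1, keepdims=True)
    return (h - mean) / jnp.sqrt(h.var(axis=-1, keepdims=True) + eps)


def decoder_layer(h, attn, ffn, do_layer_norm_before=True):
    """One OPT-style decoder layer built from stand-in attn/ffn callables."""
    # Self-attention sub-block with a residual connection
    residual = h
    if do_layer_norm_before:  # pre-LN: normalize the sub-block input
        h = layer_norm(h)
    h = residual + attn(h)
    if not do_layer_norm_before:  # post-LN: normalize after the residual add
        h = layer_norm(h)

    # Feed-forward sub-block (fc1 -> activation -> fc2), same pattern
    residual = h
    if do_layer_norm_before:
        h = layer_norm(h)
    h = residual + ffn(h)
    if not do_layer_norm_before:
        h = layer_norm(h)
    return h


k1, k2, kx = jax.random.split(jax.random.PRNGKey(0), 3)
w1 = jax.random.normal(k1, (8, 32)) * 0.1  # fc1: embed_dim -> ffn_dim
w2 = jax.random.normal(k2, (32, 8)) * 0.1  # fc2: ffn_dim -> embed_dim


def ffn(t):
    return jax.nn.relu(t @ w1) @ w2  # fc1 -> ReLU -> fc2


def attn(t):
    return 0.5 * t  # trivial stand-in for self-attention


x = jax.random.normal(kx, (2, 3, 8))
pre = decoder_layer(x, attn, ffn, do_layer_norm_before=True)
post = decoder_layer(x, attn, ffn, do_layer_norm_before=False)
```

With do_layer_norm_before=False the layer output is normalized (per-position mean near 0, variance near 1), whereas the pre-LN variant returns the un-normalized residual stream, which is why pre-LN stacks typically need a final layer norm after the last layer.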

class easydel.modules.opt.modeling_opt_flax.OPTForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

OPT Model with a Causal Language Modeling head.

This model consists of the base OPTModel followed by a linear layer (the language modeling head) that projects hidden states to next-token logits over the vocabulary.

config

Configuration object for the model.

Type

OPTConfig

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

model

The base OPT model.

Type

OPTModel

lm_head

The linear layer for projecting hidden states to vocabulary logits.

Type

ParallelLinear
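Turning hidden states into a next-token choice with lm_head amounts to a matrix product followed, for greedy decoding, by an argmax at the last position. The sketch below assumes a bias-free head stored as a (hidden_size, vocab_size) matrix; the helper names are hypothetical, and real generation would often sample or apply logits processors instead of always taking the argmax.

```python
import jax.numpy as jnp


def lm_head_logits(hidden_states, kernel):
    # (batch, seq, hidden_size) @ (hidden_size, vocab_size) -> (batch, seq, vocab_size)
    return hidden_states @ kernel


def greedy_next_token(logits):
    # The next token is predicted from the logits at the last position only
    return jnp.argmax(logits[:, -1, :], axis=-1)


# Toy setup: an identity head makes the logits equal the hidden states
kernel = jnp.eye(5)
hidden = jnp.zeros((1, 3, 5)).at[0, -1, 3].set(1.0)
logits = lm_head_logits(hidden, kernel)
next_ids = greedy_next_token(logits)
print(next_ids.tolist())  # [3]
```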

get_decoder()

Gets the decoder module from the model.

get_input_embeddings()

Gets the input embeddings from the model.

get_output_embeddings()

Gets the output embeddings (language modeling head).

prepare_inputs_for_generation(input_ids, max_length: int, pad_token_id: int, starts: int | None = None, shardings=None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None)

Sets up the initial inputs required for starting autoregressive generation.

This function initializes the Key-Value cache (past_key_values) using init_cache, calculates the initial position_ids based on the input attention_mask (or assumes a contiguous range if no mask is provided), and prepares an extended attention_mask suitable for caching. It ensures inputs are placed on the correct devices/shards.

Parameters
  • input_ids (chex.Array) – The initial sequence of token IDs. Shape (batch_size, seq_length).

  • max_length (int) – The maximum sequence length that the KV cache should support.

  • pad_token_id (int) – The ID used for padding tokens. Used to calculate starts if not provided.

  • starts (int | None) – Optional pre-calculated starting positions (number of leading pads). If None, calculated using compute_prefill_length.

  • shardings (dict | None) – Optional sharding configuration passed to init_cache.

  • attention_mask (tp.Optional[chex.Array]) – An optional mask indicating which tokens should be attended to. Shape (batch_size, seq_length).

  • token_type_ids (tp.Optional[chex.Array]) – Optional segment IDs for models that use them.

Returns

A dictionary containing the prepared inputs, typically including:
  • "past_key_values": The initialized KV cache.

  • "attention_mask": The extended attention mask for generation.

  • "position_ids": The calculated initial position IDs.

  • "token_type_ids": (Optional) Prepared token type IDs.

This dictionary is then passed through prepare_inputs_for_call.

Return type

dict
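The position and mask bookkeeping described above can be sketched in JAX as follows. This is a hypothetical sketch, not EasyDeL's code: KV-cache creation (init_cache), sharding, and token_type_ids are left out; count_leading_pads is an illustrative stand-in for compute_prefill_length, whose exact behavior is not documented here; and clamping padded positions to 0 is a simplification chosen for this sketch. The "starts" entry is returned only for illustration.

```python
import jax.numpy as jnp
from jax import lax


def count_leading_pads(input_ids, pad_token_id):
    """Hypothetical stand-in for `starts`: leading pad tokens per row."""
    is_pad = (input_ids == pad_token_id).astype(jnp.int32)
    # The running product stays 1 only while every token so far is a pad
    return jnp.cumprod(is_pad, axis=-1).sum(axis=-1)


def prepare_generation_inputs(input_ids, max_length, pad_token_id, attention_mask=None):
    """Sketch of position_ids / extended-mask setup; KV-cache init is omitted."""
    batch_size, seq_length = input_ids.shape
    starts = count_leading_pads(input_ids, pad_token_id)
    if attention_mask is None:
        # No mask: assume every token is real and positions are contiguous
        attention_mask = jnp.ones((batch_size, seq_length), dtype=jnp.int32)
        position_ids = jnp.broadcast_to(jnp.arange(seq_length), (batch_size, seq_length))
    else:
        # Count attended tokens only; padded slots are clamped to position 0 here
        position_ids = jnp.maximum(jnp.cumsum(attention_mask, axis=-1) - 1, 0)
    # Extend the mask to max_length so it also covers tokens generated later
    extended = jnp.ones((batch_size, max_length), dtype=jnp.int32)
    extended = lax.dynamic_update_slice(extended, attention_mask.astype(jnp.int32), (0, 0))
    return {"attention_mask": extended, "position_ids": position_ids, "starts": starts}


ids = jnp.array([[0, 0, 5, 6], [7, 8, 9, 10]])  # pad_token_id = 0; row 0 is left-padded
mask = jnp.array([[0, 0, 1, 1], [1, 1, 1, 1]])
inputs = prepare_generation_inputs(ids, max_length=6, pad_token_id=0, attention_mask=mask)
print(inputs["position_ids"].tolist())  # [[0, 0, 0, 1], [0, 1, 2, 3]]
```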

set_decoder(decoder)

Sets the decoder module for the model.

set_input_embeddings(value)

Sets the input embeddings for the model.

set_output_embeddings(new_embeddings)

Sets the output embeddings (language modeling head).

update_inputs_for_generation(model_outputs, model_kwargs)

Updates the keyword arguments for the next generation step.

Specifically, it takes the past_key_values from the model_outputs of the current step and updates the model_kwargs with them. It also increments the position_ids by one for the next token prediction.

Parameters
  • model_outputs – The output object from the model’s forward pass in the previous step (should contain a past_key_values attribute).

  • model_kwargs (dict) – The dictionary of keyword arguments used for the model call. It may be modified in place, or a new dictionary may be returned.

Returns

The updated model_kwargs dictionary ready for the next generation step.

Return type

dict
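A minimal sketch of the per-step update follows, assuming the model output exposes past_key_values as an attribute and that the cache already covers all earlier tokens, so only the newest position ID needs to be kept. The function name update_generation_kwargs is hypothetical, and this version returns a new dictionary rather than mutating model_kwargs.

```python
from types import SimpleNamespace

import jax.numpy as jnp


def update_generation_kwargs(model_outputs, model_kwargs):
    """Carry the KV cache forward and advance position_ids by one step."""
    updated = dict(model_kwargs)  # shallow copy: the caller's dict is left untouched
    updated["past_key_values"] = model_outputs.past_key_values
    # The cache already covers earlier tokens, so keep only the newest position
    updated["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
    return updated


kwargs = {"position_ids": jnp.array([[0, 1, 2], [0, 0, 1]]), "past_key_values": None}
outputs = SimpleNamespace(past_key_values="cache-after-step")  # stand-in model output
new_kwargs = update_generation_kwargs(outputs, kwargs)
print(new_kwargs["position_ids"].tolist())  # [[3], [2]]
```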

class easydel.modules.opt.modeling_opt_flax.OPTLearnedPositionalEmbedding(*args: Any, **kwargs: Any)

Bases: Embed

Learned positional embedding for OPT.

This module learns positional embeddings up to a maximum specified length. It includes an offset, typically used to account for padding tokens.

offset

The offset added to position IDs before embedding lookup.

Type

int
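The offset lookup can be sketched with a plain table. OffsetPositionalEmbedding below is an illustrative class, not EasyDeL's Embed subclass; offset=2 is assumed because that is the value Hugging Face's OPT port uses, and the table is enlarged by the offset so that the largest shifted ID still fits.

```python
import jax
import jax.numpy as jnp


class OffsetPositionalEmbedding:
    """Illustrative learned positional table with an index offset."""

    def __init__(self, key, max_positions, dim, offset=2):
        self.offset = offset
        # Enlarge the table by `offset` so the largest shifted ID is still in range
        self.table = jax.random.normal(key, (max_positions + offset, dim)) * 0.02

    def __call__(self, position_ids):
        # Shift every position ID by the offset before the lookup
        return self.table[position_ids + self.offset]


emb = OffsetPositionalEmbedding(jax.random.PRNGKey(0), max_positions=8, dim=4)
out = emb(jnp.arange(8)[None, :])  # positions 0..7 read table rows 2..9
print(out.shape)  # (1, 8, 4)
```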

class easydel.modules.opt.modeling_opt_flax.OPTModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Base OPT Model class.

This class represents the core OPT model architecture, consisting primarily of the OPTDecoder.

config

Configuration object for the model.

Type

OPTConfig

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

decoder

The OPT decoder stack.

Type

OPTDecoder

get_input_embeddings()

Gets the input embeddings from the model.

set_input_embeddings(value)

Sets the input embeddings for the model.