easydel.modules.opt.modeling_opt_flax#
Flax OPT model.
- class easydel.modules.opt.modeling_opt_flax.OPTAttention(*args: Any, **kwargs: Any)[source]#
Bases:
AttentionModule
OPT Attention mechanism module.
This module implements the multi-head self-attention mechanism used in the OPT model.
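As a rough illustration of how attributes such as num_heads, head_dim, and causal relate, here is a minimal JAX sketch of multi-head causal self-attention. It is not EasyDeL's implementation: the function name `causal_self_attention`, the plain weight matrices standing in for q_proj, k_proj, v_proj, and out_proj, and the assumption that head_dim equals embed_dim // num_heads are all illustrative. Dropout, bias, and KV caching are omitted.

```python
import jax
import jax.numpy as jnp


def causal_self_attention(x, wq, wk, wv, wo, num_heads):
    """Toy multi-head causal self-attention (hypothetical helper, not the EasyDeL API)."""
    batch, seq_len, embed_dim = x.shape
    head_dim = embed_dim // num_heads  # assumption: embed_dim is divisible by num_heads

    def split_heads(t):
        # (batch, seq, embed_dim) -> (batch, heads, seq, head_dim)
        return t.reshape(batch, seq_len, num_heads, head_dim).transpose(0, 2, 1, 3)

    q, k, v = split_heads(x @ wq), split_heads(x @ wk), split_heads(x @ wv)

    # Scaled dot-product scores with shape (batch, heads, seq, seq).
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(head_dim)

    # Causal mask: position i may only attend to positions <= i.
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)

    weights = jax.nn.softmax(scores, axis=-1)
    context = jnp.einsum("bhqk,bhkd->bhqd", weights, v)

    # Merge heads back to (batch, seq, embed_dim), then apply the output projection.
    context = context.transpose(0, 2, 1, 3).reshape(batch, seq_len, embed_dim)
    return context @ wo


# Small random inputs purely for demonstration.
kx, kq, kk, kv, ko = jax.random.split(jax.random.PRNGKey(0), 5)
embed_dim, num_heads = 8, 2
x = jax.random.normal(kx, (1, 4, embed_dim))
wq, wk, wv, wo = (
    jax.random.normal(k, (embed_dim, embed_dim)) * 0.1 for k in (kq, kk, kv, ko)
)
out = causal_self_attention(x, wq, wk, wv, wo, num_heads)
```

Because of the causal mask, changing the last token of `x` leaves the outputs at earlier positions unchanged, which is the property autoregressive decoding relies on.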
- embed_dim#
The dimensionality of the embedding layer.
- Type
int
- num_heads#
The number of attention heads.
- Type
int
- dropout#
Dropout probability for the attention scores.
- Type
float
- causal#
Whether to use causal masking.
- Type
bool
- bias#
Whether to include bias in the linear projections.
- Type
bool
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- head_dim#
Dimensionality of each attention head.
- Type
int
- q_proj#
Linear layer for query projection.
- Type
- k_proj#
Linear layer for key projection.
- Type
- v_proj#
Linear layer for value projection.
- Type
- out_proj#
Linear layer for the output projection.
- Type
- dropout_layer#
Dropout layer applied after attention.
- Type
nn.Dropout
- attention_module#
The core attention computation module.
- Type
- class easydel.modules.opt.modeling_opt_flax.OPTDecoder(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
OPT Decoder stack.
This module comprises the main transformer decoder layers for the OPT model, including token embeddings, positional embeddings, the decoder layers themselves, and optional final layer normalization.
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- padding_idx#
Index of the padding token.
- Type
int
- max_target_positions#
Maximum sequence length the model can handle.
- Type
int
- embed_scale#
Scaling factor for embeddings (usually 1.0).
- Type
float
- embed_tokens#
Token embedding layer.
- Type
nn.Embed
- embed_positions#
Positional embedding layer.
- Type
OPTLearnedPositionalEmbedding
- project_out#
Optional linear projection applied to the decoder output, after the decoder layers.
- Type
ParallelLinear, optional
- project_in#
Optional linear projection applied to the token embeddings, before the decoder layers.
- Type
ParallelLinear, optional
- layers#
List of OPT decoder layers.
- Type
tp.List[OPTDecoderLayer]
- dropout_layer#
Dropout layer applied after embeddings.
- Type
nn.Dropout
- final_layer_norm#
Optional final layer normalization.
- Type
nn.LayerNorm, optional
- gradient_checkpointing#
Gradient checkpointing configuration.
- class easydel.modules.opt.modeling_opt_flax.OPTDecoderLayer(*args: Any, **kwargs: Any)[source]#
Bases:
Module
OPT Decoder Layer.
This module represents a single layer in the OPT decoder stack. It consists of a self-attention mechanism, optional layer normalization, a feed-forward network (FFN), and residual connections.
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- embed_dim#
Dimensionality of the embedding layer.
- Type
int
- self_attn#
The self-attention module.
- Type
OPTAttention
- do_layer_norm_before#
Whether to apply layer normalization before the attention/FFN blocks.
- Type
bool
- dropout_layer#
Dropout layer applied to the hidden states.
- Type
nn.Dropout
- activation_fn#
The activation function used in the FFN.
- Type
callable
- self_attn_layer_norm#
Layer normalization for the self-attention block, applied before or after it depending on do_layer_norm_before.
- Type
nn.LayerNorm
- fc1#
The first linear layer of the FFN.
- Type
- fc2#
The second linear layer (output) of the FFN.
- Type
- final_layer_norm#
Layer normalization for the FFN block, applied before or after it depending on do_layer_norm_before.
- Type
nn.LayerNorm
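The effect of do_layer_norm_before on the order of normalization and residual connections can be sketched in plain JAX. This is an assumption-laden illustration rather than the library's code: `layer_norm` is a parameter-free stand-in for nn.LayerNorm, `attn_fn` is a hypothetical callable standing in for self_attn, `w1` and `w2` stand in for fc1 and fc2, ReLU is assumed as activation_fn, and dropout is left out.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-5):
    # Parameter-free layer norm over the last axis (no learned scale or bias).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def decoder_layer(x, attn_fn, w1, w2, do_layer_norm_before=True):
    """Toy decoder layer (hypothetical helper): attention block, then FFN block."""
    # Self-attention block with residual connection.
    residual = x
    if do_layer_norm_before:
        x = layer_norm(x)  # pre-norm
    x = residual + attn_fn(x)
    if not do_layer_norm_before:
        x = layer_norm(x)  # post-norm

    # Feed-forward block with residual connection.
    residual = x
    if do_layer_norm_before:
        x = layer_norm(x)
    x = residual + jax.nn.relu(x @ w1) @ w2
    if not do_layer_norm_before:
        x = layer_norm(x)
    return x


k0, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k0, (2, 3, 8))
w1 = jax.random.normal(k1, (8, 32)) * 0.1
w2 = jax.random.normal(k2, (32, 8)) * 0.1

identity_attn = lambda h: h  # placeholder for a real attention module
pre = decoder_layer(x, identity_attn, w1, w2, do_layer_norm_before=True)
post = decoder_layer(x, identity_attn, w1, w2, do_layer_norm_before=False)
```

In the post-norm variant the layer output is itself normalized, so each token vector in `post` has approximately zero mean, while the pre-norm output `pre` keeps the unnormalized residual stream.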
- class easydel.modules.opt.modeling_opt_flax.OPTForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
OPT Model with a Causal Language Modeling head.
This model consists of the base OPTModel followed by a linear layer (the language modeling head) to predict the next token logits.
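The role of the language modeling head can be shown with a small JAX sketch. It assumes the head is a bias-free matrix multiplication from the hidden size to the vocabulary size; `next_token_logits` and `lm_head_kernel` are hypothetical names, and the tiny shapes are chosen only for readability.

```python
import jax.numpy as jnp


def next_token_logits(hidden_states, lm_head_kernel):
    """Project hidden states to vocabulary logits and keep the last position."""
    logits = hidden_states @ lm_head_kernel  # (batch, seq, vocab)
    return logits[:, -1, :]  # logits for the token that follows the sequence


hidden = jnp.ones((2, 3, 4))  # (batch, seq, hidden)
kernel = jnp.eye(4, 10)       # (hidden, vocab), illustrative weights
last = next_token_logits(hidden, kernel)
next_ids = jnp.argmax(last, axis=-1)  # greedy choice of the next token ID
```

Greedy argmax is only one decoding strategy; sampling-based strategies consume the same last-position logits.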
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- lm_head#
The linear layer for projecting hidden states to vocabulary logits.
- Type
- prepare_inputs_for_generation(input_ids, max_length: int, pad_token_id: int, starts: int | None = None, shardings=None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None)[source]#
Sets up the initial inputs required for starting autoregressive generation.
This function initializes the Key-Value cache (past_key_values) using init_cache, calculates the initial position_ids based on the input attention_mask (or assumes a contiguous range if no mask is provided), and prepares an extended attention_mask suitable for caching. It ensures inputs are placed on the correct devices/shards.
- Parameters
input_ids (chex.Array) – The initial sequence of token IDs. Shape (batch_size, seq_length).
max_length (int) – The maximum sequence length that the KV cache should support.
pad_token_id (int) – The ID used for padding tokens. Used to calculate starts if not provided.
starts (int | None) – Optional pre-calculated starting positions (number of leading pads). If None, calculated using compute_prefill_length.
shardings (dict | None) – Optional sharding configuration passed to init_cache.
attention_mask (tp.Optional[chex.Array]) – An optional mask indicating which tokens should be attended to. Shape (batch_size, seq_length).
token_type_ids (tp.Optional[chex.Array]) – Optional segment IDs for models that use them.
- Returns
A dictionary containing the prepared inputs, typically including:
- "past_key_values": The initialized KV cache.
- "attention_mask": The extended attention mask for generation.
- "position_ids": The calculated initial position IDs.
- "token_type_ids": (Optional) Prepared token type IDs.
This dictionary is then passed through prepare_inputs_for_call.
- Return type
dict
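The position_ids and extended attention_mask described for prepare_inputs_for_generation can be approximated as follows. This is a sketch under stated assumptions, not the method itself: `sketch_prepare_inputs` is a hypothetical helper, the KV cache (init_cache), starts, sharding, and token_type_ids are omitted, and position IDs are assumed to be derived from a cumulative sum over the mask.

```python
import jax
import jax.numpy as jnp


def sketch_prepare_inputs(input_ids, max_length, attention_mask=None):
    """Build position IDs and a mask padded out to the cache length (illustrative)."""
    batch_size, seq_length = input_ids.shape

    # Start from an all-ones mask covering every slot the cache could hold.
    extended_mask = jnp.ones((batch_size, max_length), dtype=jnp.int32)

    if attention_mask is not None:
        # Positions count attended tokens only; leading pads end up at -1 here.
        position_ids = jnp.cumsum(attention_mask, axis=-1) - 1
        # Copy the prompt mask into the first seq_length slots.
        extended_mask = jax.lax.dynamic_update_slice(
            extended_mask, attention_mask.astype(jnp.int32), (0, 0)
        )
    else:
        # No mask given: assume a contiguous range 0..seq_length-1 per row.
        position_ids = jnp.broadcast_to(
            jnp.arange(seq_length, dtype=jnp.int32)[None, :],
            (batch_size, seq_length),
        )
    return {"attention_mask": extended_mask, "position_ids": position_ids}


input_ids = jnp.array([[0, 0, 5, 6], [7, 8, 9, 10]])  # first row is left-padded
mask = jnp.array([[0, 0, 1, 1], [1, 1, 1, 1]], dtype=jnp.int32)
prepared = sketch_prepare_inputs(input_ids, max_length=6, attention_mask=mask)
```

The slots beyond the prompt are left at 1 so that tokens generated later are attended to without rebuilding the mask at every step.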
- update_inputs_for_generation(model_outputs, model_kwargs)[source]#
Updates the keyword arguments for the next generation step.
Specifically, it takes the past_key_values from the model_outputs of the current step and updates the model_kwargs with them. It also increments the position_ids by one for the next token prediction.
- Parameters
model_outputs – The output object from the model’s forward pass in the previous step (should contain a past_key_values attribute).
model_kwargs (dict) – The dictionary of keyword arguments used for the model call. This dictionary will be modified in-place or a new one returned.
- Returns
The updated model_kwargs dictionary ready for the next generation step.
- Return type
dict
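A minimal sketch of this update step is given here. It assumes that only the last position is kept and incremented, and it returns a new dictionary rather than modifying the input; `sketch_update_inputs` is a hypothetical helper, and the string used for the cache is a placeholder for a real KV cache object.

```python
from types import SimpleNamespace

import jax.numpy as jnp


def sketch_update_inputs(model_outputs, model_kwargs):
    """Carry the cache forward and advance position IDs by one (illustrative)."""
    updated = dict(model_kwargs)  # shallow copy so the caller's dict is untouched
    updated["past_key_values"] = model_outputs.past_key_values
    # Assumption: the next step decodes one token, at last position + 1.
    updated["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
    return updated


kwargs = {"past_key_values": None, "position_ids": jnp.array([[0, 1, 2]])}
outputs = SimpleNamespace(past_key_values="cache-after-step")  # placeholder cache
nxt = sketch_update_inputs(outputs, kwargs)
```

After the call, `nxt["position_ids"]` holds a single column with the position of the token to be generated next.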
- class easydel.modules.opt.modeling_opt_flax.OPTLearnedPositionalEmbedding(*args: Any, **kwargs: Any)[source]#
Bases:
Embed
Learned positional embedding for OPT.
This module learns positional embeddings up to a maximum specified length. It includes an offset, typically used to account for padding tokens.
- offset#
The offset added to position IDs before embedding lookup.
- Type
int
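The offset lookup can be illustrated with a small JAX sketch. The helper name `lookup_positions`, the table sizes, and the offset value of 2 are assumptions chosen for the example; the embedding table is given `offset` extra rows so that shifted indices stay in range.

```python
import jax
import jax.numpy as jnp


def lookup_positions(table, position_ids, offset):
    """Shift position IDs by the offset, then gather rows from the table."""
    return jnp.take(table, position_ids + offset, axis=0)


max_positions, embed_dim, offset = 16, 4, 2  # illustrative values
table = jax.random.normal(
    jax.random.PRNGKey(0), (max_positions + offset, embed_dim)
)
position_ids = jnp.array([[0, 1, 2]])
emb = lookup_positions(table, position_ids, offset)  # (batch, seq, embed_dim)
```

With this layout, position 0 reads row `offset` of the table, so the first `offset` rows are never selected by non-negative position IDs.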
- class easydel.modules.opt.modeling_opt_flax.OPTModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
Base OPT Model class.
This class represents the core OPT model architecture, consisting primarily of the OPTDecoder.
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- decoder#
The OPT decoder stack.
- Type
OPTDecoder