easydel.modules.opt.__init__
- class easydel.modules.opt.__init__.OPTConfig(vocab_size: int = 50272, hidden_size: int = 768, num_hidden_layers: int = 12, ffn_dim: int = 3072, max_position_embeddings: int = 2048, do_layer_norm_before: bool = True, _remove_final_layer_norm: bool = False, word_embed_proj_dim: int = None, dropout: float = 0.1, attention_dropout: float = 0.0, num_attention_heads: int = 12, activation_function: str = 'relu', layerdrop: float = 0.0, init_std: float = 0.02, use_cache: bool = True, pad_token_id: int = 1, bos_token_id: int = 2, eos_token_id: int = 2, enable_bias: bool = True, layer_norm_elementwise_affine: bool = True, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)
Bases: EasyDeLBaseConfig
Configuration objects inherit from EasyDeLBaseConfig and can be used to control the model outputs. Read the documentation of EasyDeLBaseConfig for more information.
- Parameters
vocab_size (int, optional, defaults to 50272) – Vocabulary size of the OPT model. Defines the number of different tokens that can be represented by the input_ids passed when calling the model.
hidden_size (int, optional, defaults to 768) – Dimensionality of the decoder layers (the hidden states).
num_hidden_layers (int, optional, defaults to 12) – Number of decoder layers in the Transformer.
ffn_dim (int, optional, defaults to 3072) – Dimensionality of the "intermediate" (i.e., feed-forward) layer in each decoder layer.
max_position_embeddings (int, optional, defaults to 2048) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
do_layer_norm_before (bool, optional, defaults to True) – Whether to perform layer normalization before the attention block.
_remove_final_layer_norm (bool, optional, defaults to False) – Whether to remove the final layer norm.
word_embed_proj_dim (int, optional) – The dimension of the word embedding projection. If not provided, it will default to hidden_size.
dropout (float, optional, defaults to 0.1) – The dropout probability for all fully connected layers in the embeddings and the decoder.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
num_attention_heads (int, optional, defaults to 12) – Number of attention heads for each attention layer in the Transformer decoder.
activation_function (str or function, optional, defaults to "relu") – The non-linear activation function (function or string) used in the decoder. If a string, "gelu", "relu", "swish" and "gelu_new" are supported.
layerdrop (float, optional, defaults to 0.0) – The LayerDrop probability for the decoder. See the LayerDrop paper (https://arxiv.org/abs/1909.11556) for more details.
init_std (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
use_cache (bool, optional, defaults to True) – Whether the model should return the last key/value attention states, which can be reused to speed up autoregressive decoding.
pad_token_id (int, optional, defaults to 1) – The index of the padding token in the vocabulary.
bos_token_id (int, optional, defaults to 2) – The id of the beginning-of-sequence token.
eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.
enable_bias (bool, optional, defaults to True) – Whether to use bias in the linear layers.
layer_norm_elementwise_affine (bool, optional, defaults to True) – Whether to use elementwise affine in the layer normalization layers.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing strategy.
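A few sizes follow directly from these parameters. The sketch below uses a hypothetical stand-in class, OPTConfigSketch (not part of EasyDeL), that mirrors some of the documented defaults, applies the documented fallback of word_embed_proj_dim to hidden_size, and derives the per-head dimension. It assumes the usual convention that hidden_size must split evenly across num_attention_heads, and that a separate embedding projection is only needed when word_embed_proj_dim differs from hidden_size.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class OPTConfigSketch:
    """Hypothetical mirror of a subset of OPTConfig's documented defaults."""

    vocab_size: int = 50272
    hidden_size: int = 768
    num_hidden_layers: int = 12
    ffn_dim: int = 3072
    num_attention_heads: int = 12
    word_embed_proj_dim: Optional[int] = None

    def __post_init__(self) -> None:
        # Documented behavior: fall back to hidden_size when no projection dim is given.
        if self.word_embed_proj_dim is None:
            self.word_embed_proj_dim = self.hidden_size
        # Assumed convention: the heads must split hidden_size evenly.
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by "
                f"num_attention_heads ({self.num_attention_heads})"
            )

    @property
    def head_dim(self) -> int:
        # Width of each attention head.
        return self.hidden_size // self.num_attention_heads

    @property
    def needs_embedding_projection(self) -> bool:
        # Assumed: a projection between embedding and hidden widths is needed only when they differ.
        return self.word_embed_proj_dim != self.hidden_size


default = OPTConfigSketch()
print(default.head_dim, default.word_embed_proj_dim)  # 64 768

small = OPTConfigSketch(hidden_size=1024, num_attention_heads=16, word_embed_proj_dim=512)
print(small.head_dim, small.needs_embedding_projection)  # 64 True
```

With EasyDeL installed, the real class takes the same keyword arguments listed in the signature, e.g. OPTConfig(hidden_size=1024, num_attention_heads=16, word_embed_proj_dim=512).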
- attach_custom_arguments(vocab_size: int = 50272, hidden_size: int = 768, num_hidden_layers: int = 12, ffn_dim: int = 3072, max_position_embeddings: int = 2048, do_layer_norm_before: bool = True, _remove_final_layer_norm: bool = False, word_embed_proj_dim: int = None, dropout: float = 0.1, attention_dropout: float = 0.0, num_attention_heads: int = 12, activation_function: str = 'relu', layerdrop: float = 0.0, init_std: float = 0.02, use_cache: bool = True, pad_token_id: int = 1, bos_token_id: int = 2, eos_token_id: int = 2, enable_bias: bool = True, layer_norm_elementwise_affine: bool = True, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)
Attaches custom arguments to the configuration object.
This method adds configuration attributes dynamically. It iterates through the provided arguments and sets each one as an attribute of the configuration object only if that attribute does not already exist.
- Parameters
vocab_size (int, optional) – Vocabulary size. Defaults to 50272.
hidden_size (int, optional) – Dimensionality of the decoder layers (the hidden states). Defaults to 768.
num_hidden_layers (int, optional) – Number of hidden layers. Defaults to 12.
ffn_dim (int, optional) – Dimensionality of the feed-forward layer. Defaults to 3072.
max_position_embeddings (int, optional) – Maximum sequence length. Defaults to 2048.
do_layer_norm_before (bool, optional) – Whether to apply layer norm before attention. Defaults to True.
_remove_final_layer_norm (bool, optional) – Whether to remove the final layer norm. Defaults to False.
word_embed_proj_dim (int, optional) – Dimension of the word embedding projection. Defaults to hidden_size.
dropout (float, optional) – Dropout probability. Defaults to 0.1.
attention_dropout (float, optional) – Attention dropout probability. Defaults to 0.0.
num_attention_heads (int, optional) – Number of attention heads. Defaults to 12.
activation_function (str, optional) – Activation function name. Defaults to “relu”.
layerdrop (float, optional) – LayerDrop probability. Defaults to 0.0.
init_std (float, optional) – Initialization standard deviation. Defaults to 0.02.
use_cache (bool, optional) – Whether to use key/value cache. Defaults to True.
pad_token_id (int, optional) – Padding token ID. Defaults to 1.
bos_token_id (int, optional) – Beginning-of-sequence token ID. Defaults to 2.
eos_token_id (int, optional) – End-of-sequence token ID. Defaults to 2.
enable_bias (bool, optional) – Whether to use bias in linear layers. Defaults to True.
layer_norm_elementwise_affine (bool, optional) – Whether layer norm uses elementwise affine parameters. Defaults to True.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing strategy. Defaults to EasyDeLGradientCheckPointers.NONE.
**kwargs – Additional keyword arguments to attach.
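The set-if-missing behavior described above can be sketched in a few lines. The helper name attach_custom_arguments_sketch is hypothetical and SimpleNamespace stands in for a real configuration object; the sketch assumes the documented rule that attributes which already exist are left unchanged.

```python
from types import SimpleNamespace
from typing import Any


def attach_custom_arguments_sketch(config: Any, **kwargs: Any) -> Any:
    """Set each keyword as an attribute only if the config does not already have it."""
    for name, value in kwargs.items():
        if not hasattr(config, name):
            setattr(config, name, value)
    return config


cfg = SimpleNamespace(hidden_size=1024)  # stand-in for a configuration object
attach_custom_arguments_sketch(cfg, hidden_size=768, use_cache=True)
print(cfg.hidden_size, cfg.use_cache)  # 1024 True
```

The existing hidden_size survives, while the missing use_cache attribute is added.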
- get_partition_rules(fully_sharded_data_parallel: bool = True)
Get the rules used to partition (shard) the model's parameters across devices.
- Parameters
fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.
- Returns
The partition rules, as a tuple of (parameter-name pattern, PartitionSpec) pairs.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
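A rule table of this shape is typically consumed by matching each parameter path against the patterns in order and taking the first hit. The sketch below assumes regex patterns with first-match-wins semantics; the rules, the parameter paths, and the mesh-axis names "fsdp" and "tp" are hypothetical, plain tuples stand in for jax.sharding.PartitionSpec, and the effect of fully_sharded_data_parallel is not modeled.

```python
import re
from typing import Optional, Sequence, Tuple

# Stand-in for jax.sharding.PartitionSpec: a tuple of mesh-axis names (or None).
Spec = Tuple[Optional[str], ...]

# Hypothetical rules; the real OPT rules and parameter paths may differ.
RULES: Sequence[Tuple[str, Spec]] = (
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"self_attn/(q_proj|k_proj|v_proj)/kernel", ("fsdp", "tp")),
    (r"self_attn/out_proj/kernel", ("tp", "fsdp")),
    (r"fc1/kernel", ("fsdp", "tp")),
    (r"fc2/kernel", ("tp", "fsdp")),
    (r"lm_head/kernel", ("fsdp", "tp")),
    (r".*", ()),  # catch-all: an empty spec means fully replicated
)


def match_rule(param_path: str, rules: Sequence[Tuple[str, Spec]] = RULES) -> Spec:
    """Return the spec of the first rule whose pattern is found in param_path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


print(match_rule("decoder/layers/0/self_attn/q_proj/kernel"))  # ('fsdp', 'tp')
print(match_rule("decoder/final_layer_norm/scale"))  # ()
```

Ending the table with a catch-all pattern guarantees every parameter receives some spec, so an unmatched name never silently goes unsharded by accident.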
- keys_to_ignore_at_inference = ['past_key_values']
- model_type: str = 'opt'
- class easydel.modules.opt.__init__.OPTForCausalLM(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
OPT Model with a Causal Language Modeling head.
This model consists of the base OPTModel followed by a linear layer (the language modeling head) to predict the next token logits.
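The shape bookkeeping of the language modeling head can be illustrated with NumPy and toy sizes. This is a sketch, not EasyDeL's lm_head: it assumes a bias-free projection with an (in, out) kernel layout as in Flax dense kernels, ignores any weight tying, and picks the next token greedily from the last position.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, hidden_size, vocab_size = 2, 5, 16, 50  # toy sizes

# Output of the base model: one hidden vector per position.
hidden_states = rng.standard_normal((batch, seq_len, hidden_size)).astype(np.float32)
# Hypothetical LM-head kernel in (in_features, out_features) layout.
lm_head_kernel = rng.standard_normal((hidden_size, vocab_size)).astype(np.float32)

# Project every position onto the vocabulary: (batch, seq_len, vocab_size).
logits = hidden_states @ lm_head_kernel
# Greedy choice of the next token from the logits at the final position.
next_token_ids = logits[:, -1, :].argmax(axis=-1)
print(logits.shape, next_token_ids.shape)  # (2, 5, 50) (2,)
```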
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- lm_head
The linear layer for projecting hidden states to vocabulary logits.
- Type
- prepare_inputs_for_generation(input_ids, max_length: int, pad_token_id: int, starts: int | None = None, shardings=None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None)
Sets up the initial inputs required for starting autoregressive generation.
This function initializes the Key-Value cache (past_key_values) using init_cache, calculates the initial position_ids based on the input attention_mask (or assumes a contiguous range if no mask is provided), and prepares an extended attention_mask suitable for caching. It ensures inputs are placed on the correct devices/shards.
- Parameters
input_ids (chex.Array) – The initial sequence of token IDs. Shape (batch_size, seq_length).
max_length (int) – The maximum sequence length that the KV cache should support.
pad_token_id (int) – The ID used for padding tokens. Used to calculate starts if not provided.
starts (int | None) – Optional pre-calculated starting positions (number of leading pads). If None, calculated using compute_prefill_length.
shardings (dict | None) – Optional sharding configuration passed to init_cache.
attention_mask (tp.Optional[chex.Array]) – An optional mask indicating which tokens should be attended to. Shape (batch_size, seq_length).
token_type_ids (tp.Optional[chex.Array]) – Optional segment IDs for models that use them. This argument does not appear in the signature above.
- Returns
A dictionary containing the prepared inputs, typically including:
"past_key_values": The initialized KV cache.
"attention_mask": The extended attention mask for generation.
"position_ids": The calculated initial position IDs.
"token_type_ids": (Optional) Prepared token type IDs.
This dictionary is then passed through prepare_inputs_for_call.
- Return type
dict
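The prefill bookkeeping described above (leading-pad count, initial position_ids, mask extended to max_length) can be sketched in NumPy. Everything here is illustrative: prepare_prefill_sketch and count_leading_pads are hypothetical names, cache initialization and sharding are omitted, and the cumulative-sum position scheme for masked inputs is an assumption modeled on common Flax generation code rather than EasyDeL's exact implementation.

```python
from typing import Dict, Optional

import numpy as np


def count_leading_pads(input_ids: np.ndarray, pad_token_id: int) -> np.ndarray:
    """Per-row number of leading pad tokens (a stand-in for the starts value)."""
    is_pad = (input_ids == pad_token_id).astype(np.int64)
    # cumprod stays 1 only while every token so far is a pad.
    return np.cumprod(is_pad, axis=-1).sum(axis=-1)


def prepare_prefill_sketch(
    input_ids: np.ndarray,
    max_length: int,
    pad_token_id: int,
    attention_mask: Optional[np.ndarray] = None,
) -> Dict[str, np.ndarray]:
    batch_size, seq_length = input_ids.shape
    starts = count_leading_pads(input_ids, pad_token_id)
    if attention_mask is None:
        # No mask: assume a contiguous range of positions 0..seq_length-1.
        position_ids = np.broadcast_to(np.arange(seq_length), (batch_size, seq_length))
        attention_mask = np.ones((batch_size, seq_length), dtype=np.int32)
    else:
        # With a mask: count only attended tokens; leading pads are clipped to position 0.
        position_ids = np.clip(np.cumsum(attention_mask, axis=-1) - 1, 0, None)
    # Extend the mask to max_length so positions filled later by generation are attendable.
    extended_mask = np.ones((batch_size, max_length), dtype=np.int32)
    extended_mask[:, :seq_length] = attention_mask
    return {"starts": starts, "position_ids": position_ids, "attention_mask": extended_mask}


ids = np.array([[1, 1, 5, 6], [7, 8, 9, 10]])  # pad_token_id=1; first row is left-padded
mask = (ids != 1).astype(np.int32)
out = prepare_prefill_sketch(ids, max_length=8, pad_token_id=1, attention_mask=mask)
print(out["starts"].tolist())        # [2, 0]
print(out["position_ids"].tolist())  # [[0, 0, 0, 1], [0, 1, 2, 3]]
```

Left-padded rows thus start counting positions at their first real token, so both prompts see position 0 on their first non-pad token.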
- update_inputs_for_generation(model_outputs, model_kwargs)
Updates the keyword arguments for the next generation step.
Specifically, it takes the past_key_values from the model_outputs of the current step and updates the model_kwargs with them. It also increments the position_ids by one for the next token prediction.
- Parameters
model_outputs – The output object from the model’s forward pass in the previous step (should contain a past_key_values attribute).
model_kwargs (dict) – The dictionary of keyword arguments used for the model call. It may be modified in place, or a new dictionary may be returned.
- Returns
The updated model_kwargs dictionary ready for the next generation step.
- Return type
dict
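One decoding step of this update can be sketched as follows. StepOutput and update_inputs_sketch are hypothetical stand-ins; the sketch returns a new dictionary instead of mutating the input, and it assumes, as in common Flax generation loops, that only the last position of each row is kept and advanced by one.

```python
from dataclasses import dataclass
from typing import Any, Dict

import numpy as np


@dataclass
class StepOutput:
    """Hypothetical stand-in for a model output that carries a KV cache."""

    logits: np.ndarray
    past_key_values: Any


def update_inputs_sketch(model_outputs: StepOutput, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Return new kwargs holding the latest cache and position_ids advanced by one."""
    updated = dict(model_kwargs)  # shallow copy: the caller's dict is left untouched
    updated["past_key_values"] = model_outputs.past_key_values
    # Keep only the last position per row and step it forward for the next token.
    updated["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
    return updated


kwargs = {"position_ids": np.array([[0, 0, 0, 1], [0, 1, 2, 3]]), "past_key_values": None}
step = StepOutput(logits=np.zeros((2, 1, 10)), past_key_values="cache-step-1")
new_kwargs = update_inputs_sketch(step, kwargs)
print(new_kwargs["position_ids"].tolist())  # [[2], [4]]
print(new_kwargs["past_key_values"])  # cache-step-1
```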
- class easydel.modules.opt.__init__.OPTModel(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
Base OPT Model class.
This class represents the core OPT model architecture, consisting primarily of the OPTDecoder.
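Inside the decoder stack, the do_layer_norm_before configuration flag selects where layer normalization sits relative to each residual connection. The NumPy sketch below contrasts pre-LN (normalize, run the sublayer, add the residual) with post-LN (run the sublayer, add the residual, then normalize). It assumes, as in the Hugging Face OPT reference implementation, that the flag governs both the attention and the feed-forward sub-blocks; attention is replaced by an identity placeholder, LayerNorm has no learned affine parameters, and all function names are hypothetical, so it shows ordering only, not EasyDeL's OPTDecoder.

```python
from typing import Callable

import numpy as np

Sublayer = Callable[[np.ndarray], np.ndarray]


def layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """LayerNorm over the last axis, without learned scale or bias."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def residual_block(x: np.ndarray, sublayer: Sublayer, do_layer_norm_before: bool) -> np.ndarray:
    """Apply one residual sub-block in pre-LN or post-LN order."""
    if do_layer_norm_before:
        # Pre-LN: normalize the input, run the sublayer, then add the residual.
        return x + sublayer(layer_norm(x))
    # Post-LN: run the sublayer, add the residual, then normalize the sum.
    return layer_norm(x + sublayer(x))


def decoder_layer_sketch(x: np.ndarray, attn: Sublayer, ffn: Sublayer, do_layer_norm_before: bool = True) -> np.ndarray:
    """Attention sub-block followed by a feed-forward sub-block."""
    x = residual_block(x, attn, do_layer_norm_before)
    return residual_block(x, ffn, do_layer_norm_before)


rng = np.random.default_rng(0)
hidden = rng.standard_normal((2, 4, 8))  # (batch, seq_len, hidden_size)
w1, w2 = rng.standard_normal((8, 32)), rng.standard_normal((32, 8))


def ffn(h: np.ndarray) -> np.ndarray:
    # ReLU MLP: hidden_size -> ffn_dim -> hidden_size.
    return np.maximum(h @ w1, 0.0) @ w2


def attn(h: np.ndarray) -> np.ndarray:
    # Identity placeholder standing in for self-attention.
    return h


pre = decoder_layer_sketch(hidden, attn, ffn, do_layer_norm_before=True)
post = decoder_layer_sketch(hidden, attn, ffn, do_layer_norm_before=False)
print(pre.shape, post.shape)  # (2, 4, 8) (2, 4, 8)
```

In the post-LN variant every layer output is normalized, whereas pre-LN leaves the residual stream unnormalized, which is why pre-LN stacks usually end with a final layer norm.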
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- decoder
The OPT decoder stack.
- Type