# Source code for easydel.modules.openelm.openelm_configuration

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import typing as tp
from numbers import Number

from jax import numpy as jnp
from jax.sharding import PartitionSpec

from easydel.infra.base_module import EasyDeLBaseConfig
from easydel.infra.etils import EasyDeLGradientCheckPointers
from easydel.infra.factory import register_config


def make_divisible(
    v: tp.Union[float, int],
    divisor: tp.Optional[int] = 8,
    min_value: tp.Optional[tp.Union[float, int]] = None,
) -> tp.Union[float, int]:
    """Snap ``v`` to a multiple of ``divisor``.

    Adapted from the TensorFlow MobileNet implementation, where it keeps every
    layer's channel count divisible by ``divisor``:
    https://github.com/tensorflow/models/blob/2cfc99eff5e5eb729c6793d2f3d03aa1c9be2b15/research/slim/nets/mobilenet/mobilenet.py#L62

    Args:
        v: Value to adjust.
        divisor: Step the result is aligned to. Defaults to 8.
        min_value: Lower bound for the result. When ``None`` the bound is
            ``divisor`` itself.

    Returns:
        The larger of ``min_value`` and ``v`` rounded to a multiple of
        ``divisor``, increased by one extra ``divisor`` if that candidate is
        below 90% of ``v``.
    """
    lower_bound = divisor if min_value is None else min_value

    # Adding half a divisor before the floor division rounds to the nearest
    # multiple instead of always rounding down.
    nearest_multiple = int(v + divisor / 2) // divisor * divisor
    candidate = max(lower_bound, nearest_multiple)

    # Rounding down must not cost more than 10% of the requested value.
    if candidate < 0.9 * v:
        return candidate + divisor
    return candidate
def compute_heads(model_dim: int, head_dim: int) -> int:
    """Compute the number of attention heads for a given model width.

    Args:
        model_dim: Model dimension.
        head_dim: Dimension of a single attention head. Must be positive.

    Returns:
        ``model_dim // head_dim``, the number of heads in multi-head attention.

    Raises:
        ValueError: If ``head_dim`` is not positive, or if ``model_dim`` is not
            divisible by ``head_dim``.
    """
    # Reject zero/negative head sizes up front: zero would otherwise surface as
    # a ZeroDivisionError and a negative value would yield a negative count.
    if head_dim <= 0:
        raise ValueError(f"Head dimension should be a positive integer. Got: {head_dim}.")

    if model_dim % head_dim != 0:
        raise ValueError(
            f"Model dimension should be divisible by head dimension. Got: {model_dim} and {head_dim}."
        )
    return model_dim // head_dim
@register_config("openelm")
class OpenELMConfig(EasyDeLBaseConfig):
    """Configuration for OpenELM models.

    Configuration objects inherit from [`EasyDeLBaseConfig`] and can be used to control the model
    outputs. Read the documentation from [`EasyDeLBaseConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the OpenELM model. Defines the number of different tokens that can be
            represented by the `inputs_ids` passed to the forward method.
        max_context_length (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with.
        num_transformer_layers (`int`, *optional*, defaults to 12):
            Number of transformer layers.
        model_dim (`int`, *optional*, defaults to 2048):
            Model dimension.
        head_dim (`int`, *optional*, defaults to 128):
            Dimensionality of the attention heads.
        qkv_multipliers (`float` or `list` of `float`, *optional*, defaults to 1.0):
            The multiplier for the query, key, and value projections. Either a single number applied
            to every layer, or exactly two numbers that are linearly interpolated across the layers.
        num_query_heads (`int`, *optional*):
            Number of query heads. If not provided, it is calculated from `model_dim` and `head_dim`.
            Note that `__post_init__` replaces this attribute with a per-layer list derived from
            `qkv_multipliers`.
        num_gqa_groups (`int`, *optional*, defaults to 1):
            Number of GQA (Grouped Query Attention) groups.
        ffn_multipliers (`float` or `list` of `float`, *optional*, defaults to 4.0):
            The multiplier for the feed-forward network. Either a single number applied to every
            layer, two numbers that are linearly interpolated across the layers, or one number per
            layer.
        ffn_with_glu (`bool`, *optional*, defaults to `True`):
            Whether to use a gated linear unit (GLU) in the feed-forward network.
        ffn_dim_divisor (`int`, *optional*, defaults to 256):
            The divisor for the feed-forward network dimension.
        activation_fn_name (`str`, *optional*, defaults to `"swish"`):
            The activation function to use.
        normalization_layer_name (`str`, *optional*, defaults to `"rms_norm"`):
            The normalization layer to use.
        normalize_qk_projections (`bool`, *optional*, defaults to `False`):
            Whether to normalize the query and key projections.
        share_input_output_layers (`bool`, *optional*, defaults to `False`):
            Whether to share the input and output layers.
        rope_freq_constant (`int`, *optional*, defaults to 10000):
            The frequency constant for Rotary Position Embeddings (RoPE).
        rope_max_length (`int`, *optional*, defaults to 4096):
            The maximum length for RoPE.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight
            matrices.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all
            models).
        bos_token_id (`int`, *optional*, defaults to 1):
            The id of the *beginning-of-sequence* token.
        eos_token_id (`int`, *optional*, defaults to 2):
            The id of the *end-of-sequence* token.
        rope_scaling (`tp.Dict[str, tp.Union[str, float]]`, *optional*):
            The configuration for rope scaling.
        gradient_checkpointing (`EasyDeLGradientCheckPointers`, *optional*, defaults to
            `EasyDeLGradientCheckPointers.NONE`):
            The gradient checkpointing configuration.
        use_scan_mlp (`bool`, *optional*, defaults to `False`):
            Whether to use the scan implementation for the MLP.
        scan_mlp_chunk_size (`int`, *optional*, defaults to 1024):
            The chunk size to use when scanning the MLP.
        bits (`int`, *optional*):
            The number of bits to quantize the model to.
    """

    model_type: str = "openelm"
    attribute_map = {"tie_word_embedding": "share_input_output_layers"}

    def __init__(
        self,
        vocab_size: int = 32000,
        max_context_length: int = 2048,
        num_transformer_layers: int = 12,
        model_dim: int = 2048,
        head_dim: int = 128,
        qkv_multipliers: tp.Union[Number, tp.List[Number]] = 1.0,
        num_query_heads: tp.Union[int, None] = None,
        num_gqa_groups: int = 1,
        ffn_multipliers: tp.Union[Number, tp.List[Number]] = 4.0,
        ffn_with_glu: bool = True,
        ffn_dim_divisor: int = 256,
        activation_fn_name: str = "swish",
        normalization_layer_name: str = "rms_norm",
        normalize_qk_projections: bool = False,
        share_input_output_layers: bool = False,
        rope_freq_constant: int = 10000,
        rope_max_length: int = 4096,
        initializer_range: float = 0.02,
        use_cache: bool = True,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        rope_scaling: tp.Optional[tp.Dict[str, tp.Union[str, float]]] = None,
        gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE,
        use_scan_mlp: bool = False,
        scan_mlp_chunk_size: int = 1024,
        bits: tp.Optional[int] = None,
        **kwargs,
    ):
        """Store the OpenELM hyper-parameters and derive the per-layer settings.

        Every named argument is described in the class docstring and is stored
        on the instance under the same name. ``bos_token_id``, ``eos_token_id``,
        ``use_scan_mlp``, ``scan_mlp_chunk_size`` and ``bits`` are additionally
        forwarded to ``EasyDeLBaseConfig.__init__`` together with ``**kwargs``.
        Finally ``__post_init__`` is invoked to expand the multipliers into
        per-layer values.

        Raises:
            ValueError: If ``num_query_heads`` is ``None`` and ``model_dim`` is
                not divisible by ``head_dim`` (raised by ``compute_heads``).
            NotImplementedError: If ``qkv_multipliers`` or ``ffn_multipliers``
                has an unsupported form (raised by ``__post_init__``).
        """
        self.vocab_size = vocab_size
        self.max_context_length = max_context_length
        self.num_transformer_layers = num_transformer_layers
        self.model_dim = model_dim
        self.head_dim = head_dim
        self.qkv_multipliers = qkv_multipliers
        self.num_gqa_groups = num_gqa_groups
        self.ffn_multipliers = ffn_multipliers
        self.ffn_with_glu = ffn_with_glu
        self.ffn_dim_divisor = ffn_dim_divisor
        self.activation_fn_name = activation_fn_name
        self.normalization_layer_name = normalization_layer_name
        self.normalize_qk_projections = normalize_qk_projections
        self.share_input_output_layers = share_input_output_layers
        self.rope_freq_constant = rope_freq_constant
        self.rope_max_length = rope_max_length
        self.num_query_heads = (
            compute_heads(model_dim=model_dim, head_dim=head_dim)
            if num_query_heads is None
            else num_query_heads
        )
        self.initializer_range = initializer_range
        self.bits = bits
        self.use_cache = use_cache
        self.rope_scaling = rope_scaling
        self.gradient_checkpointing = gradient_checkpointing
        self.use_scan_mlp = use_scan_mlp
        self.scan_mlp_chunk_size = scan_mlp_chunk_size

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            use_scan_mlp=use_scan_mlp,
            scan_mlp_chunk_size=scan_mlp_chunk_size,
            bits=bits,
            **kwargs,
        )
        self.__post_init__()

    def get_partition_rules(self, *args, **kwargs):
        """Get the partition rules for the model.

        Each rule pairs a parameter-path pattern with the `PartitionSpec` used
        for parameters matching it; the trailing ``".*"`` rule is the
        replicated fallback.

        Returns:
            `tp.Tuple[tp.Tuple[str, PartitionSpec]]`: The partition rules.
        """
        return (
            ("model/embed_tokens/embedding", PartitionSpec("tp", ("fsdp", "sp"))),
            (
                "self_attn/(q_proj|k_proj|v_proj)/kernel",
                PartitionSpec(("fsdp", "sp"), "tp"),
            ),
            ("self_attn/o_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),
            ("mlp/gate_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),
            ("mlp/down_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),
            ("mlp/up_proj/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
            ("input_layernorm/kernel", PartitionSpec(None)),
            ("post_attention_layernorm/kernel", PartitionSpec(None)),
            ("model/norm/kernel", PartitionSpec(None)),
            ("lm_head/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
            (".*", PartitionSpec(None)),
        )

    @staticmethod
    def get_weight_decay_exclusions():
        """Return the weight-decay exclusions; this config defines none (empty tuple)."""
        return tuple()

    @staticmethod
    def rng_keys():
        """Return the RNG stream names used by the model."""
        return "params", "dropout", "fcm"

    @staticmethod
    def _layerwise_multipliers(start: Number, stop: Number, num_layers: int) -> tp.List[float]:
        """Linearly interpolate a multiplier from ``start`` to ``stop`` over the layers.

        Values are rounded to two decimals and returned as plain Python floats
        rather than 0-d JAX arrays, so they can be stored on the config and
        serialised with the standard ``json`` module.
        """
        return [
            round(float(value), 2)
            for value in jnp.linspace(start, stop, num=num_layers, dtype=float)
        ]

    def __post_init__(self) -> None:
        """Expand the QKV/FFN multipliers into per-layer settings.

        Sets ``num_query_heads`` and ``num_kv_heads`` to per-layer lists and
        normalises ``ffn_multipliers`` to a list with one entry per layer.

        Raises:
            NotImplementedError: If ``qkv_multipliers`` is neither a number nor
                a two-element list/tuple, or if ``ffn_multipliers`` is neither
                a number nor a list/tuple.
            AssertionError: If a per-layer ``ffn_multipliers`` list does not
                have ``num_transformer_layers`` entries, or if some layer's
                query-head count is not a multiple of its KV-head count.
        """
        head_multiple_of = self.num_gqa_groups if self.num_gqa_groups is not None else 2
        # QKV widths must split evenly into heads and into GQA groups.
        qkv_divisor = self.head_dim * head_multiple_of

        if isinstance(self.qkv_multipliers, Number):
            # All attention layers have the same latent dimensions, resulting in
            # uniform allocation of parameters.
            qkv_dim = make_divisible(
                self.model_dim * self.qkv_multipliers,  # type:ignore
                divisor=qkv_divisor,
            )
            query_dims = [int(qkv_dim)] * self.num_transformer_layers
        elif isinstance(self.qkv_multipliers, (tuple, list)) and len(self.qkv_multipliers) == 2:
            # Each attention layer has different latent dimensions assuming
            # qkv_multipliers[0] != qkv_multipliers[1]. This results in variable
            # allocation of parameters in the attention layers, known as layer-wise
            # or block-wise scaling: https://arxiv.org/abs/2008.00623
            qkv_multipliers = self._layerwise_multipliers(
                self.qkv_multipliers[0],
                self.qkv_multipliers[1],
                self.num_transformer_layers,
            )
            # Make sure that scaled model dimension is divisible by scaled head dimension.
            query_dims = [
                int(make_divisible(self.model_dim * m, divisor=qkv_divisor))
                for m in qkv_multipliers
            ]
        else:
            # Report the configured value; no local list exists on this path.
            raise NotImplementedError(
                f"QKV multipliers should be a single number or a list containing exactly two numbers. Got: {self.qkv_multipliers}."
            )

        # Compute the number of query, key, and value heads.
        # For multi-head and multi-query attention, the number of heads for query,
        # key, and value are the same. For group query attention, the number of key
        # and value heads are the same.
        self.num_query_heads = [int(compute_heads(q_dim, self.head_dim)) for q_dim in query_dims]
        self.num_kv_heads = [q_heads // self.num_gqa_groups for q_heads in self.num_query_heads]

        # Feed-forward network (FFN) multipliers.
        if isinstance(self.ffn_multipliers, Number):
            # All FFN layers have the same latent dimensions, resulting in uniform
            # allocation of parameters.
            self.ffn_multipliers = [self.ffn_multipliers] * self.num_transformer_layers
        elif isinstance(self.ffn_multipliers, (tuple, list)):
            # Each FFN layer has different latent dimensions assuming
            # ffn_multipliers[0] != ffn_multipliers[1]. This results in variable
            # allocation of parameters in the FFN layers, known as layer-wise or
            # block-wise scaling: https://arxiv.org/abs/2008.00623
            if len(self.ffn_multipliers) == 2:
                self.ffn_multipliers = self._layerwise_multipliers(
                    self.ffn_multipliers[0],
                    self.ffn_multipliers[1],
                    self.num_transformer_layers,
                )
            else:
                # Otherwise the list must already hold one multiplier per layer.
                assert len(self.ffn_multipliers) == self.num_transformer_layers, (
                    f"{len(self.ffn_multipliers)=}!={self.num_transformer_layers=}"
                )
        else:
            raise NotImplementedError(
                f"FFN multipliers should be a single number or a list containing exactly two numbers. Got: {self.ffn_multipliers}."
            )

        # Check num_query_heads is divisible by num_kv_heads for every layer.
        for q_heads, kv_heads in zip(self.num_query_heads, self.num_kv_heads):
            assert q_heads % kv_heads == 0

    @property
    def granted_freq_max_position_embedding(self) -> int:
        """``freq_max_position_embeddings`` if set, otherwise ``max_context_length``."""
        return getattr(
            self,
            "freq_max_position_embeddings",
            self.max_context_length,
        )

    @property
    def granted_mask_max_position_embedding(self) -> int:
        """``mask_max_position_embeddings`` if set, otherwise ``max_context_length``."""
        return getattr(
            self,
            "mask_max_position_embeddings",
            self.max_context_length,
        )