easydel.layers.rotary_embedding#
- easydel.layers.rotary_embedding.AVAILABLE_ROPE_TYPES#
Registered keys and the class each entry is built from: 'deepseek_yarn' (DeepseekScalingRotaryEmbedding), 'default' (RotaryEmbedding), 'dynamic' (DynamicNTKScalingRotaryEmbedding), 'linear' (LinearScalingRotaryEmbedding), 'llama3' (Llama3RotaryEmbedding), 'longrope' (Phi3LongRoPEScaledRotaryEmbedding), 'yarn' (YaRNScalingRotaryEmbedding). Each value is a dictionary of that class's attributes (__abstractmethods__, __call__, __doc__, __init__, __module__, _abc_impl).
A dictionary to store registered RoPE (Rotary Position Embedding) types and their configurations.
- class easydel.layers.rotary_embedding.DeepseekScalingRotaryEmbedding(*args: Any, **kwargs: Any)[source]#
Bases: Module
RotaryEmbedding implementing a YaRN-like scaling method, potentially from Deepseek models.
Uses YaRN parameters (beta_fast, beta_slow, extrapolation_factor) and includes additional m-scale parameters (mscale, mscale_all_dim). This version has a custom __call__ method differing slightly from apply_basic_rope.
- head_size#
Dimension of each attention head.
- Type
int
- rotary_dim#
Dimension subjected to rotary embedding.
- Type
int
- max_position_embeddings#
Original maximum sequence length before scaling.
- Type
int
- base#
Base for frequency calculation.
- Type
int
- is_neox_style#
Use Neox rotation if True, GPT-J otherwise.
- Type
bool
- dtype#
Data type for embeddings.
- Type
jnp.dtype
- scaling_factor#
Primary scaling factor.
- Type
float
- extrapolation_factor#
YaRN extrapolation factor.
- Type
float
- attn_factor#
Attention scaling factor.
- Type
float
- beta_fast#
YaRN parameter.
- Type
int
- beta_slow#
YaRN parameter.
- Type
int
- mscale#
Parameter for m-scale calculation.
- Type
float
- mscale_all_dim#
Parameter for m-scale calculation.
- Type
float
- class easydel.layers.rotary_embedding.DynamicNTKScalingRotaryEmbedding(*args: Any, **kwargs: Any)[source]#
Bases: RotaryEmbedding
RotaryEmbedding extended with Dynamic NTK scaling.
Dynamically adjusts the base parameter based on the scaling factor.
- scaling_factor#
The scaling factor applied to sequence length and base calculation.
- Type
float
- Inherits other attributes from RotaryEmbedding.
- class easydel.layers.rotary_embedding.LinearScalingRotaryEmbedding(*args: Any, **kwargs: Any)[source]#
Bases: RotaryEmbedding
RotaryEmbedding extended with Linear Scaling.
Linearly scales the position indices before calculating frequencies.
- scaling_factors#
The factor(s) to scale positions by.
- Type
tp.Union[tp.List[float], float]
- Inherits other attributes from RotaryEmbedding.
- class easydel.layers.rotary_embedding.Llama3RotaryEmbedding(*args: Any, **kwargs: Any)[source]#
Bases: RotaryEmbedding
RotaryEmbedding implementing the Llama-3 scaling method.
Adjusts frequencies based on wavelength thresholds (low_freq_factor, high_freq_factor) and applies an overall scaling factor.
- scaling_factor#
Overall scaling factor.
- Type
float
- low_freq_factor#
Factor related to low frequency wavelength threshold.
- Type
float
- high_freq_factor#
Factor related to high frequency wavelength threshold.
- Type
float
- orig_max_position#
Original maximum sequence length before scaling.
- Type
int
- Inherits other attributes from RotaryEmbedding.
- class easydel.layers.rotary_embedding.Phi3LongRoPEScaledRotaryEmbedding(*args: Any, **kwargs: Any)[source]#
Bases: Module
RotaryEmbedding using the Phi-3 LongRoPE scaling method.
Applies different frequency scaling factors (short_factor, long_factor) depending on the target sequence length relative to the original maximum. Requires rotary_dim to be equal to head_size.
- head_size#
Dimension of each attention head. Must equal rotary_dim.
- Type
int
- rotary_dim#
Dimension subjected to rotary embedding. Must equal head_size.
- Type
int
- max_position_embeddings#
The target maximum sequence length after scaling.
- Type
int
- original_max_position_embeddings#
Original maximum sequence length before scaling.
- Type
int
- base#
Base for frequency calculation.
- Type
int
- is_neox_style#
Flag indicating whether Neox-style rotation is assumed (used by apply_phi3_rope).
- Type
bool
- dtype#
Data type for computations.
- Type
jnp.dtype
- short_factor#
Scaling factors applied when target length <= original max length.
- Type
tp.List[float]
- long_factor#
Scaling factors applied when target length > original max length.
- Type
tp.List[float]
- class easydel.layers.rotary_embedding.RopeConfig(rope_type: str = 'default', factor: Optional[float] = None, low_freq_factor: Optional[float] = None, high_freq_factor: Optional[float] = None, original_max_position_embeddings: Optional[int] = None, long_factor: Optional[float] = None, short_factor: Optional[float] = None, long_mscale: Optional[float] = None, short_mscale: Optional[float] = None)[source]#
Bases: Mapping
Configuration class for RoPE (Rotary Position Embedding) parameters.
Stores the configuration related to RoPE type and its scaling parameters, making it easy to manage and pass around RoPE settings.
- rope_type#
The type of RoPE scaling to use (e.g., "default", "linear", "yarn", "llama3"). Defaults to "default".
- Type
str
- factor#
General scaling factor used by some types (linear, dynamic, yarn, llama3).
- Type
tp.Optional[float]
- low_freq_factor#
Specific factor for Llama3 scaling.
- Type
tp.Optional[float]
- high_freq_factor#
Specific factor for Llama3 scaling.
- Type
tp.Optional[float]
- original_max_position_embeddings#
Original context window size, required by some scaling methods (yarn, llama3, phi3, deepseek).
- Type
tp.Optional[int]
- long_factor#
Specific factor for Phi3 LongRoPE scaling (used for lengths > original).
- Type
tp.Optional[float]
- short_factor#
Specific factor for Phi3 LongRoPE scaling (used for lengths <= original).
- Type
tp.Optional[float]
- long_mscale#
Potentially used by variants like Phi3. (Not used in current get_rope).
- Type
tp.Optional[float]
- short_mscale#
Potentially used by variants like Phi3. (Not used in current get_rope).
- Type
tp.Optional[float]
- Additional scaling parameters (e.g., from YaRN, Deepseek):
- extrapolation_factor#
YaRN/Deepseek parameter.
- Type
tp.Optional[float]
- attn_factor#
YaRN/Deepseek parameter.
- Type
tp.Optional[float]
- beta_fast#
YaRN/Deepseek parameter.
- Type
tp.Optional[int]
- beta_slow#
YaRN/Deepseek parameter.
- Type
tp.Optional[int]
- mscale#
Deepseek parameter.
- Type
tp.Optional[float]
- mscale_all_dim#
Deepseek parameter.
- Type
tp.Optional[float]
- factor: Optional[float] = None#
- classmethod from_dict(config_dict: Dict[str, Any]) -> RopeConfig[source]#
Create a RopeConfig instance from a dictionary.
Accepts "type" as an alias for "rope_type".
- Parameters
config_dict (tp.Dict[str, tp.Any]) – Dictionary containing RoPE configuration.
- Returns
An instance populated from the dictionary.
- Return type
RopeConfig
- from_tuple()#
- high_freq_factor: Optional[float] = None#
- items() -> a set-like object providing a view on D's items#
- keys() -> a set-like object providing a view on D's keys#
- long_factor: Optional[float] = None#
- long_mscale: Optional[float] = None#
- low_freq_factor: Optional[float] = None#
- original_max_position_embeddings: Optional[int] = None#
- replace(**kwargs)#
- rope_type: str = 'default'#
- short_factor: Optional[float] = None#
- short_mscale: Optional[float] = None#
- to_dict() -> Dict[str, Any][source]#
Convert the RopeConfig instance to a dictionary.
Filters out attributes with None values. The dictionary is made hashable using a custom class for potential use with JIT compilation contexts (though making the dict itself static in get_frequencies is preferred).
- Returns
A hashable dictionary containing non-None configuration values.
- Return type
tp.Dict[str, tp.Any]
- to_tuple()#
- values() -> an object providing a view on D's values#
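The from_dict and to_dict behaviour described above (the "type" alias and the dropping of None values) can be sketched with a small dataclass. RopeConfigSketch is a hypothetical stand-in with only three of the documented fields. Dropping unknown keys is an assumption, and the sketch returns a plain dict rather than the hashable one the real to_dict is documented to produce.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional


@dataclass
class RopeConfigSketch:
    """Hypothetical stand-in for RopeConfig with only three of its fields."""

    rope_type: str = "default"
    factor: Optional[float] = None
    original_max_position_embeddings: Optional[int] = None

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "RopeConfigSketch":
        data = dict(config_dict)  # copy so the caller's dict is not mutated
        alias = data.pop("type", None)
        if alias is not None:
            # "type" is only used when "rope_type" is not given explicitly.
            data.setdefault("rope_type", alias)
        # Assumption: keys that are not declared fields are silently dropped.
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in data.items() if k in known})

    def to_dict(self) -> Dict[str, Any]:
        # Keep only the attributes that were actually set.
        return {
            f.name: getattr(self, f.name)
            for f in fields(self)
            if getattr(self, f.name) is not None
        }


cfg = RopeConfigSketch.from_dict({"type": "linear", "factor": 2.0})
print(cfg.to_dict())  # {'rope_type': 'linear', 'factor': 2.0}
```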
- class easydel.layers.rotary_embedding.RotaryEmbedding(*args: Any, **kwargs: Any)[source]#
Bases: Module
Standard Rotary Positional Embedding (RoPE) module.
- head_size#
The dimension size of each attention head.
- Type
int
- rotary_dim#
The dimension size of the rotary embeddings applied. Can be <= head_size.
- Type
int
- max_position_embeddings#
The maximum sequence length the model can handle.
- Type
int
- base#
The base value for calculating frequencies.
- Type
int
- is_neox_style#
Flag indicating whether to use Neox-style rotation.
- Type
bool
- dtype#
Data type for computations.
- Type
jnp.dtype
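To illustrate what is_neox_style and a rotary_dim smaller than head_size mean in practice, here is a plain-Python sketch of rotating one head vector. The function name rotate_vector is hypothetical, and the sketch works on Python lists rather than JAX arrays. It shows the two common pairing conventions (Neox pairs element i with element i + half, GPT-J pairs adjacent elements) and is not taken from the library source.

```python
import math
from typing import List


def rotate_vector(x: List[float], angles: List[float], is_neox_style: bool) -> List[float]:
    """Rotate the first 2 * len(angles) entries of x; pass the rest through."""
    half = len(angles)
    rot, rest = x[: 2 * half], x[2 * half :]
    if is_neox_style:
        # Neox style: element i is paired with element i + half.
        x1, x2 = rot[:half], rot[half:]
    else:
        # GPT-J style: adjacent elements (2i, 2i + 1) form a pair.
        x1, x2 = rot[0::2], rot[1::2]
    # Standard 2D rotation of each (x1, x2) pair by its angle.
    o1 = [a * math.cos(t) - b * math.sin(t) for a, b, t in zip(x1, x2, angles)]
    o2 = [b * math.cos(t) + a * math.sin(t) for a, b, t in zip(x1, x2, angles)]
    if is_neox_style:
        out = o1 + o2
    else:
        out = [v for pair in zip(o1, o2) for v in pair]
    # Dimensions beyond the rotary part are left untouched.
    return out + rest


vec = [1.0, 0.0, 0.0, 1.0, 5.0, 6.0]  # head_size 6, rotary part 4
neox = rotate_vector(vec, [math.pi / 2, 0.0], is_neox_style=True)
gptj = rotate_vector(vec, [math.pi / 2, 0.0], is_neox_style=False)
```

Both calls leave the last two entries (5.0, 6.0) unchanged, which corresponds to the partial-rotation case where rotary_dim is smaller than head_size.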
- class easydel.layers.rotary_embedding.YaRNScalingRotaryEmbedding(*args: Any, **kwargs: Any)[source]#
Bases: RotaryEmbedding
RotaryEmbedding extended with the YaRN (Yet another RoPE extensioN method) scaling.
Combines interpolation and extrapolation with frequency correction and magnitude scaling.
- scaling_factor#
The primary scaling factor for context length.
- Type
tp.Union[float, int]
- extrapolation_factor#
Controls the strength of extrapolation correction.
- Type
float
- attn_factor#
Scales the output attention values.
- Type
float
- beta_fast#
YaRN parameter for high-frequency dimensions correction range.
- Type
int
- beta_slow#
YaRN parameter for low-frequency dimensions correction range.
- Type
int
- Inherits other attributes from RotaryEmbedding. Note: max_position_embeddings in the parent init likely refers to the *original* max length for YaRN calculations.
- easydel.layers.rotary_embedding.apply_basic_rope(query: jax.Array, key: jax.Array, positions: jax.Array, frequencies: jax.Array, rotary_dim: int, is_neox_style: bool, offsets: jax.Array = None, dtype: numpy.dtype = jax.numpy.float32)[source]#
Applies standard or partially applied RoPE to query and key tensors.
Selects frequencies based on positions (and optional offsets), then applies the rotation using _apply_rotary_emb. Handles cases where RoPE is applied only to a subset of the head dimension (rotary_dim < query.shape[-1]).
- Parameters
query (jax.Array) – Query tensor. Shape [..., sequence_length, num_heads, head_dim].
key (jax.Array) – Key tensor. Shape [..., sequence_length, num_heads, head_dim].
positions (jax.Array) – Array of positions for lookup in the frequency cache. Shape [sequence_length].
frequencies (jax.Array) – Precomputed frequency cache. Shape [max_length, rotary_dim_freq].
rotary_dim (int) – The dimension up to which RoPE is applied.
is_neox_style (bool) – Whether to use Neox-style rotation.
offsets (jax.Array, optional) – Optional offsets to add to positions. Defaults to None.
dtype (jnp.dtype, optional) – Output dtype. Defaults to jnp.float32.
- Returns
The rotated query and key tensors with the specified dtype.
- Return type
- easydel.layers.rotary_embedding.apply_phi3_rope(query, key, positions, frequencies, offsets: jax.Array = None, dtype: numpy.dtype = jax.numpy.float32)[source]#
Applies Phi-3 LongRoPE to query and key tensors.
Uses a specific rotation application style (_rotate_neox) assumed by Phi-3.
- Parameters
query (jax.Array) – Query tensor. Shape [batch_size, sequence_length, num_heads, head_dim].
key (jax.Array) – Key tensor. Shape [batch_size, sequence_length, num_heads, head_dim].
positions (jax.Array) – Array of positions. Shape [sequence_length].
frequencies (jax.Array) – Precomputed Phi-3 frequency cache. Shape [1, max_length, rotary_dim].
offsets (jax.Array, optional) – Optional offsets to add to positions. Defaults to None.
dtype (jnp.dtype, optional) – Output dtype. Defaults to jnp.float32.
- Returns
The rotated query and key tensors with the specified dtype.
- Return type
- easydel.layers.rotary_embedding.compute_basic_frequencies(base: int, rotary_dim: int, max_position_embeddings: int)[source]#
Computes the basic RoPE frequencies (cos and sin values) for all positions.
- Parameters
base (int) – The base value for the geometric progression of frequencies.
rotary_dim (int) – The dimension of the rotary embeddings.
max_position_embeddings (int) – The maximum sequence length.
- Returns
A frequency cache tensor of shape (max_position_embeddings, rotary_dim). Contains concatenated cos and sin values.
- Return type
jnp.ndarray
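The cache layout described above (one row per position, cos values followed by sin values) can be sketched in plain Python. This is a minimal illustration with nested lists instead of a jnp.ndarray. The name basic_frequency_cache is hypothetical, and the cos-first ordering follows the "concatenated cos and sin" wording rather than the library source.

```python
import math
from typing import List


def basic_frequency_cache(
    base: float, rotary_dim: int, max_position_embeddings: int
) -> List[List[float]]:
    # One inverse frequency per pair of rotary dimensions.
    inv_freq = [1.0 / (base ** (i / rotary_dim)) for i in range(0, rotary_dim, 2)]
    cache = []
    for pos in range(max_position_embeddings):
        angles = [pos * f for f in inv_freq]
        # Each row holds rotary_dim // 2 cos values, then rotary_dim // 2 sin values.
        cache.append([math.cos(a) for a in angles] + [math.sin(a) for a in angles])
    return cache


cache = basic_frequency_cache(10000.0, 8, 16)
print(len(cache), len(cache[0]))  # 16 8
```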
- easydel.layers.rotary_embedding.compute_basic_inv_frequencies(base: int, rotary_dim: int)[source]#
Computes the inverse frequencies for standard RoPE.
- Parameters
base (int) – The base value for the geometric progression of frequencies.
rotary_dim (int) – The dimension of the rotary embeddings.
- Returns
An array of inverse frequencies of shape (rotary_dim // 2,).
- Return type
jnp.ndarray
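The standard RoPE inverse frequencies form a geometric progression, 1 / base^(2k / rotary_dim) for k = 0 .. rotary_dim // 2 - 1. A minimal plain-Python sketch follows. The function name is hypothetical and it returns a list rather than a jnp.ndarray.

```python
from typing import List


def basic_inv_frequencies(base: float, rotary_dim: int) -> List[float]:
    """Return 1 / base**(i / rotary_dim) for i = 0, 2, 4, ..., rotary_dim - 2."""
    return [1.0 / (base ** (i / rotary_dim)) for i in range(0, rotary_dim, 2)]


inv = basic_inv_frequencies(10000.0, 8)
print(len(inv))  # 4, i.e. rotary_dim // 2
```

With base 10000 and rotary_dim 8 the values are approximately 1, 0.1, 0.01 and 0.001, so each successive pair of dimensions rotates ten times more slowly.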
- easydel.layers.rotary_embedding.compute_deepseek_frequencies(base, rotary_dim, scaling_factor, extrapolation_factor, beta_fast, beta_slow, max_position_embeddings, mscale, mscale_all_dim, attn_factor) -> Array[source]#
Computes RoPE frequencies using the Deepseek-YaRN scaling method.
Similar to YaRN but potentially uses different mscale calculation parameters.
- Parameters
base (float) – The base value for positional encoding.
rotary_dim (int) – The dimension of the rotary embeddings.
scaling_factor (float) – The factor by which the context length is scaled.
extrapolation_factor (float) – YaRN parameter controlling extrapolation strength.
beta_fast (int) – YaRN parameter for faster rotating dimensions.
beta_slow (int) – YaRN parameter for slower rotating dimensions.
max_position_embeddings (int) – Original maximum sequence length before scaling.
mscale (float) – Parameter for yarn_get_mscale calculation.
mscale_all_dim (float) – Parameter for yarn_get_mscale calculation.
attn_factor (float) – Scaling factor applied to attention outputs.
- Returns
A frequency cache tensor of shape (max_position_embeddings * scaling_factor, rotary_dim).
- Return type
jnp.ndarray
- easydel.layers.rotary_embedding.compute_dynamic_frequencies(base: int, rotary_dim: int, max_position_embeddings: int, scaling_factor: float)[source]#
Computes RoPE frequencies using Dynamic NTK scaling.
Adjusts the "base" dynamically based on the scaling factor.
- Parameters
base (int) – The initial base value before dynamic adjustment.
rotary_dim (int) – The dimension of the rotary embeddings.
max_position_embeddings (int) – The base maximum sequence length before scaling.
scaling_factor (float) – The scaling factor applied to the sequence length.
- Returns
A frequency cache tensor of shape (max_position_embeddings * scaling_factor, rotary_dim).
- Return type
jnp.ndarray
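As a rough illustration of how the base might be adjusted, the sketch below uses one widely used form of the NTK-aware base formula, evaluated at the full scaled length. Whether compute_dynamic_frequencies uses exactly this expression is an assumption, and the function name dynamic_ntk_base is hypothetical.

```python
def dynamic_ntk_base(
    base: float, rotary_dim: int, max_position_embeddings: int, scaling_factor: float
) -> float:
    # Assumed form: evaluate the NTK adjustment at the scaled maximum length.
    max_len = max_position_embeddings * scaling_factor
    ratio = scaling_factor * max_len / max_position_embeddings - (scaling_factor - 1)
    # The exponent rotary_dim / (rotary_dim - 2) spreads the change across
    # all frequency bands instead of stretching positions uniformly.
    return base * ratio ** (rotary_dim / (rotary_dim - 2))


unchanged = dynamic_ntk_base(10000.0, 64, 2048, 1.0)
enlarged = dynamic_ntk_base(10000.0, 64, 2048, 2.0)
```

With a scaling factor of 1 the base is returned unchanged. Larger factors raise the base, which lowers every inverse frequency and so lengthens the wavelengths.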
- easydel.layers.rotary_embedding.compute_linear_frequencies(base: int, rotary_dim: int, max_position_embeddings: int, scaling_factors: List[float])[source]#
Computes RoPE frequencies using linear scaling for potentially multiple factors.
This function computes frequency caches for each scaling factor and concatenates them. Note: This implementation seems designed for a specific use case where different parts of a sequence might use different scaling factors, determined by offsets. If only one scaling factor is used, it behaves like standard linear scaling.
- Parameters
base (int) – The base value for the geometric progression of frequencies.
rotary_dim (int) – The dimension of the rotary embeddings.
max_position_embeddings (int) – The base maximum sequence length before scaling.
scaling_factors (tp.Union[tp.List[float], float]) – A single scaling factor or a list of scaling factors.
- Returns
A frequency cache tensor. If multiple scaling factors are provided, the caches are concatenated along the position dimension. Shape is (total_scaled_length, rotary_dim).
- Return type
jnp.ndarray
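The concatenation described above can be sketched as follows: each factor gets its own block of rows, with positions divided by that factor before the angles are computed. The function name linear_frequency_cache and the per-factor length int(max_position_embeddings * factor) are assumptions made for illustration, and lists stand in for jnp.ndarray.

```python
import math
from typing import List, Sequence, Union


def linear_frequency_cache(
    base: float,
    rotary_dim: int,
    max_position_embeddings: int,
    scaling_factors: Union[Sequence[float], float],
) -> List[List[float]]:
    if isinstance(scaling_factors, (int, float)):
        scaling_factors = [scaling_factors]
    inv_freq = [1.0 / (base ** (i / rotary_dim)) for i in range(0, rotary_dim, 2)]
    cache = []
    for factor in scaling_factors:
        scaled_length = int(max_position_embeddings * factor)
        for pos in range(scaled_length):
            # Dividing the position by the factor squeezes the longer
            # sequence back into the original position range.
            angles = [(pos / factor) * f for f in inv_freq]
            cache.append([math.cos(a) for a in angles] + [math.sin(a) for a in angles])
    return cache


cache = linear_frequency_cache(10000.0, 4, 8, [2.0, 4.0])
print(len(cache))  # 48 rows: 8 * 2 for the first factor plus 8 * 4 for the second
```

In this layout the block for the second factor starts at row 16, so an offset of 16 added to the positions would select it.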
- easydel.layers.rotary_embedding.compute_llama3_frequencies(base, rotary_dim, low_freq_factor, high_freq_factor, scaling_factor, max_position_embeddings: int)[source]#
Computes RoPE frequencies using the Llama3 scaling method.
- Parameters
base (float) – The base value for positional encoding.
rotary_dim (int) – The dimension of the rotary embeddings.
low_freq_factor (float) – Factor for adjusting low-frequency components.
high_freq_factor (float) – Factor for adjusting high-frequency components.
scaling_factor (float) – The overall scaling factor applied.
max_position_embeddings (int) – Original maximum sequence length (referred to as orig_max_position in compute_llama3_inv_frequencies). This defines the length of the frequency cache.
- Returns
A frequency cache tensor of shape (max_position_embeddings, rotary_dim).
- Return type
jnp.ndarray
- easydel.layers.rotary_embedding.compute_llama3_inv_frequencies(base, rotary_dim, low_freq_factor, high_freq_factor, orig_max_position, scaling_factor)[source]#
Computes the inverse frequencies for Llama3-style scaled RoPE.
Adjusts frequencies based on wavelength thresholds and a smoothing factor.
- Parameters
base (float) – The base value for positional encoding.
rotary_dim (int) – The dimension of the rotary embeddings.
low_freq_factor (float) – Factor for adjusting low-frequency components.
high_freq_factor (float) – Factor for adjusting high-frequency components.
orig_max_position (int) – Original maximum sequence length before scaling.
scaling_factor (float) – The overall scaling factor applied.
- Returns
An array of Llama3-adjusted inverse frequencies of shape (rotary_dim // 2,).
- Return type
jnp.ndarray
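The wavelength thresholds and smoothing factor mentioned above can be sketched as below. This follows the commonly published Llama-3 rule: short wavelengths are kept, long wavelengths are divided by the scaling factor, and the band in between is blended. It is assumed, not confirmed, that compute_llama3_inv_frequencies matches it term for term. The function name is hypothetical.

```python
import math
from typing import List


def llama3_inv_frequencies_sketch(
    base: float,
    rotary_dim: int,
    low_freq_factor: float,
    high_freq_factor: float,
    orig_max_position: int,
    scaling_factor: float,
) -> List[float]:
    low_freq_wavelen = orig_max_position / low_freq_factor
    high_freq_wavelen = orig_max_position / high_freq_factor
    out = []
    for i in range(0, rotary_dim, 2):
        inv_freq = 1.0 / (base ** (i / rotary_dim))
        wavelen = 2 * math.pi / inv_freq
        if wavelen < high_freq_wavelen:
            # High-frequency band: left untouched.
            out.append(inv_freq)
        elif wavelen > low_freq_wavelen:
            # Low-frequency band: fully scaled down.
            out.append(inv_freq / scaling_factor)
        else:
            # Transition band: linear blend between the two regimes.
            smooth = (orig_max_position / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            out.append((1 - smooth) * inv_freq / scaling_factor + smooth * inv_freq)
    return out


freqs = llama3_inv_frequencies_sketch(500000.0, 128, 1.0, 4.0, 8192, 8.0)
print(len(freqs))  # 64, i.e. rotary_dim // 2
```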
- easydel.layers.rotary_embedding.compute_phi3_frequencies(base, head_size, rotary_dim, max_position_embeddings, original_max_position_embeddings, short_factor, long_factor)[source]#
Computes RoPE frequencies using the Phi-3 LongRoPE scaling method.
Applies different scaling factors based on whether the target length is shorter or longer than the original max length. Includes a scaling factor adjustment based on the ratio of target length to original length.
- Parameters
base (float) – The base value for positional encoding.
head_size (int) – The dimension of each attention head.
rotary_dim (int) – The dimension of the rotary embeddings. Must equal head_size for Phi-3.
max_position_embeddings (int) – The target maximum sequence length after scaling.
original_max_position_embeddings (int) – Original maximum sequence length before scaling.
short_factor (tp.List[float]) – Scaling factors for frequencies when max_position_embeddings <= original_max_position_embeddings.
long_factor (tp.List[float]) – Scaling factors for frequencies when max_position_embeddings > original_max_position_embeddings.
- Returns
A frequency cache tensor of shape (1, max_position_embeddings, rotary_dim).
- Return type
jnp.ndarray
- Raises
ValueError – If rotary_dim does not equal head_size.
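The two ingredients described above are the choice between short_factor and long_factor and a magnitude adjustment based on the length ratio. A minimal sketch follows. The square-root-of-log form of the adjustment is the one commonly published for Phi-3 LongRoPE and is assumed here rather than read from the library source. Both function names are hypothetical.

```python
import math
from typing import List, Sequence


def phi3_inv_frequencies_sketch(
    base: float,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    original_max_position_embeddings: int,
    short_factor: Sequence[float],
    long_factor: Sequence[float],
) -> List[float]:
    if rotary_dim != head_size:
        raise ValueError("rotary_dim must equal head_size for Phi-3 LongRoPE")
    # Pick the per-dimension factors according to the target length.
    if max_position_embeddings > original_max_position_embeddings:
        ext_factors = long_factor
    else:
        ext_factors = short_factor
    return [
        1.0 / (ext * base ** (i / rotary_dim))
        for ext, i in zip(ext_factors, range(0, rotary_dim, 2))
    ]


def phi3_magnitude_scale(
    max_position_embeddings: int, original_max_position_embeddings: int
) -> float:
    # Assumed form of the length-ratio adjustment applied to cos and sin.
    scale = max_position_embeddings / original_max_position_embeddings
    if scale <= 1.0:
        return 1.0
    return math.sqrt(1.0 + math.log(scale) / math.log(original_max_position_embeddings))


inv = phi3_inv_frequencies_sketch(10000.0, 4, 4, 8192, 4096, [1.0, 1.0], [2.0, 4.0])
print(inv)  # [0.5, 0.0025]
```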
- easydel.layers.rotary_embedding.compute_yarn_frequencies(base: float, rotary_dim: int, beta_fast: float, beta_slow: float, max_position_embeddings: int, scaling_factor: float, extrapolation_factor: float, attn_factor: float) -> Array[source]#
Computes RoPE frequencies using the YaRN scaling method.
Includes adjustments based on YaRN parameters and applies an mscale factor.
- Parameters
base (float) – The base value for positional encoding.
rotary_dim (int) – The dimension of the rotary embeddings.
beta_fast (float) – YaRN parameter for faster rotating dimensions.
beta_slow (float) – YaRN parameter for slower rotating dimensions.
max_position_embeddings (int) – Original maximum sequence length before scaling.
scaling_factor (float) – The factor by which the context length is scaled.
extrapolation_factor (float) – YaRN parameter controlling extrapolation strength.
attn_factor (float) – YaRN parameter scaling the attention outputs.
- Returns
A frequency cache tensor of shape (max_position_embeddings * scaling_factor, rotary_dim).
- Return type
jnp.ndarray
- easydel.layers.rotary_embedding.compute_yarn_inv_frequencies(base: float, rotary_dim: int, beta_fast: float, beta_slow: float, max_position_embeddings: int, scaling_factor: float, extrapolation_factor: float) -> Array[source]#
Computes the inverse frequencies for YaRN scaled RoPE.
Combines interpolation and extrapolation frequencies based on correction ranges.
- Parameters
base (float) – The base value for positional encoding.
rotary_dim (int) – The dimension of the rotary embeddings.
beta_fast (float) – YaRN parameter for faster rotating dimensions.
beta_slow (float) – YaRN parameter for slower rotating dimensions.
max_position_embeddings (int) – Original maximum sequence length before scaling.
scaling_factor (float) – The factor by which the context length is scaled.
extrapolation_factor (float) – YaRN parameter controlling extrapolation strength.
- Returns
An array of YaRN-adjusted inverse frequencies of shape (rotary_dim // 2,).
- Return type
jnp.ndarray
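The blend of interpolated and extrapolated frequencies can be sketched as below. beta_fast and beta_slow are turned into a range of dimension indices, a linear ramp over that range builds a mask, and the mask mixes the two sets of frequencies. This follows the commonly published YaRN recipe. The helper names are hypothetical, and exact agreement with compute_yarn_inv_frequencies is an assumption.

```python
import math
from typing import List, Tuple


def find_correction_dim(
    num_rotations: float, dim: int, base: float, max_position_embeddings: int
) -> float:
    # Dimension index whose wavelength completes num_rotations over the original context.
    return (
        dim
        * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
        / (2 * math.log(base))
    )


def find_correction_range(
    low_rot: float, high_rot: float, dim: int, base: float, max_position_embeddings: int
) -> Tuple[int, int]:
    low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim - 1)


def yarn_inv_frequencies_sketch(
    base: float,
    rotary_dim: int,
    beta_fast: float,
    beta_slow: float,
    max_position_embeddings: int,
    scaling_factor: float,
    extrapolation_factor: float,
) -> List[float]:
    low, high = find_correction_range(
        beta_fast, beta_slow, rotary_dim, base, max_position_embeddings
    )
    span = (high - low) if high != low else 0.001  # avoid division by zero
    out = []
    for k in range(rotary_dim // 2):
        pos_freq = base ** (2 * k / rotary_dim)
        extrapolation = 1.0 / pos_freq
        interpolation = 1.0 / (scaling_factor * pos_freq)
        ramp = min(max((k - low) / span, 0.0), 1.0)
        # mask = 1 keeps the original frequency, mask = 0 uses the interpolated one.
        mask = (1.0 - ramp) * extrapolation_factor
        out.append(interpolation * (1.0 - mask) + extrapolation * mask)
    return out


freqs = yarn_inv_frequencies_sketch(10000.0, 64, 32, 1, 4096, 4.0, 1.0)
print(len(freqs))  # 32, i.e. rotary_dim // 2
```

With these inputs the fastest-rotating dimension keeps its original frequency, while the slowest one is divided by the scaling factor.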
- easydel.layers.rotary_embedding.get_frequencies(head_size: int, rotary_dim: int, max_position: int, base: int, rope_scaling: Optional[Dict[str, Any]] = None, partial_rotary_factor: float = 1.0) -> Array[source]#
Computes and returns the RoPE frequency cache based on configuration.
Selects the appropriate frequency computation function (basic, linear, dynamic, YaRN, Llama3, Phi3, Deepseek) based on the rope_scaling dictionary. This function is JIT-compiled for performance, with relevant parameters marked static.
- Parameters
head_size (int) – Dimension of each attention head (needed for some scaling types like Phi3).
rotary_dim (int) – Base dimension for rotary embedding (before partial factor).
max_position (int) – Maximum sequence length for which to compute frequencies. This might be the original or target length depending on scaling type.
base (int) – Base value for frequency calculation.
rope_scaling (tp.Optional[tp.Dict[str, tp.Any]], optional) – Dictionary specifying the type and parameters of RoPE scaling. Determines which frequency function to call. Defaults to None (uses compute_basic_frequencies).
partial_rotary_factor (float, optional) – Factor to reduce the rotary dimension. Defaults to 1.0.
- Returns
The computed frequency cache tensor. Shape depends on the scaling method, typically [computed_length, rotary_dim_effective].
- Return type
jax.Array
- Raises
ValueError – If rope_scaling specifies an unknown rope_type.
- easydel.layers.rotary_embedding.get_inv_frequencies(head_size: int, rotary_dim: int, max_position: int, base: int, rope_scaling: Optional[Dict[str, Any]] = None, partial_rotary_factor: float = 1.0) -> Array[source]#
Computes and returns just the inverse frequencies for RoPE based on configuration.
Similar to get_frequencies but returns only the inverse frequencies without computing the full frequency cache (no cos/sin transformation).
- Parameters
head_size (int) – Dimension of each attention head (needed for some scaling types like Phi3).
rotary_dim (int) – Base dimension for rotary embedding (before partial factor).
max_position (int) – Maximum sequence length the model should support.
base (int) – Base value for frequency calculation.
rope_scaling (tp.Optional[tp.Dict[str, tp.Any]], optional) – Dictionary specifying the type and parameters of RoPE scaling. Determines which frequency function to call. Defaults to None (uses basic inverse frequencies).
partial_rotary_factor (float, optional) – Factor to reduce the rotary dimension. Defaults to 1.0.
- Returns
The computed inverse frequencies. Shape is typically (rotary_dim // 2,).
- Return type
jax.Array
- Raises
ValueError – If rope_scaling specifies an unknown rope_type.
- easydel.layers.rotary_embedding.get_rope(head_size: int, rotary_dim: int, max_position: int, base: int, is_neox_style: bool = True, rope_scaling: Optional[Dict[str, Any]] = None, dtype: Optional[dtype] = None, partial_rotary_factor: float = 1.0) -> RotaryEmbedding[source]#
Factory function to create and return a RotaryEmbedding instance based on configuration.
Selects the appropriate RoPE class (standard, linear, dynamic, YaRN, Llama3, Phi3, Deepseek) based on the rope_scaling dictionary.
- Parameters
head_size (int) – Dimension of each attention head.
rotary_dim (int) – Base dimension for rotary embedding (before partial factor).
max_position (int) – Maximum sequence length the model should support (target length).
base (int) – Base value for frequency calculation.
is_neox_style (bool, optional) – Use Neox rotation style. Defaults to True.
rope_scaling (tp.Optional[tp.Dict[str, tp.Any]], optional) – Dictionary specifying the type and parameters of RoPE scaling. If None or "rope_type" is "default", uses standard RoPE. Keys like "rope_type", "factor", "original_max_position_embeddings", etc., are used. Defaults to None.
dtype (tp.Optional[jnp.dtype], optional) – Data type for embeddings. Defaults to jnp.float32.
partial_rotary_factor (float, optional) – Factor to reduce the rotary dimension (e.g., 0.5 applies RoPE to half the dimensions). Defaults to 1.0.
- Returns
An instance of the configured RotaryEmbedding subclass.
- Return type
RotaryEmbedding
- Raises
ValueError – If rope_scaling specifies an unknown rope_type.
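For orientation, the following shows what rope_scaling dictionaries might look like, using only key names that appear in the RopeConfig signature. The numeric values are illustrative, not recommended settings. The helper effective_rotary_dim is hypothetical and assumes the partial factor is applied by simple multiplication and truncation.

```python
from typing import Any, Dict

# Illustrative rope_scaling dictionaries; values are examples only.
linear_scaling: Dict[str, Any] = {"rope_type": "linear", "factor": 2.0}

llama3_scaling: Dict[str, Any] = {
    "rope_type": "llama3",
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}


def effective_rotary_dim(rotary_dim: int, partial_rotary_factor: float = 1.0) -> int:
    # Assumption: the factor simply shrinks the number of rotated dimensions.
    return int(rotary_dim * partial_rotary_factor)


print(effective_rotary_dim(128, 0.5))  # 64
```

A dictionary of this shape would be passed as the rope_scaling argument of get_rope, get_frequencies or get_inv_frequencies.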
- easydel.layers.rotary_embedding.rope_wraper(type)[source]#
A decorator factory that registers a RotaryEmbedding class under a specific type name.
This allows retrieving RoPE configurations by type name later. It also sets basic __str__ and __repr__ for the decorated class.
- Parameters
type (str) – The name to register the RoPE class under (e.g., "linear", "yarn").
- Returns
A decorator function that takes a RotaryEmbedding class, registers it, and returns the class.
- Return type
Callable
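The registration pattern described above can be sketched with a generic decorator factory. ROPE_REGISTRY and register_rope are hypothetical names standing in for AVAILABLE_ROPE_TYPES and rope_wraper. The sketch stores the class itself, whereas the real registry appears to store a dictionary of the class's attributes.

```python
from typing import Callable, Dict, Type

ROPE_REGISTRY: Dict[str, Type] = {}  # hypothetical stand-in for AVAILABLE_ROPE_TYPES


def register_rope(type_name: str) -> Callable[[Type], Type]:
    """Return a decorator that records a class under type_name."""

    def decorator(cls: Type) -> Type:
        ROPE_REGISTRY[type_name] = cls

        def _describe(self) -> str:
            return f"{cls.__name__}(type={type_name!r})"

        # Give the class a basic string form, as the real decorator is said to do.
        cls.__str__ = _describe
        cls.__repr__ = _describe
        return cls

    return decorator


@register_rope("linear")
class LinearRopeSketch:
    pass


print(ROPE_REGISTRY["linear"] is LinearRopeSketch)  # True
print(repr(LinearRopeSketch()))  # LinearRopeSketch(type='linear')
```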
- easydel.layers.rotary_embedding.yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float[source]#
Calculates the mscale factor, potentially used by Deepseek-YaRN or similar methods.
Allows specifying an additional mscale parameter compared to _yarn_get_mscale.
- Parameters
scale (float, optional) – The scaling factor. Defaults to 1.
mscale (float, optional) – An additional scaling parameter. Defaults to 1.
- Returns
The calculated mscale value. Returns 1.0 if scale <= 1.
- Return type
float
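A minimal sketch of the calculation, assuming the logarithmic form commonly used in YaRN-style implementations (0.1 * mscale * ln(scale) + 1). Only the "returns 1.0 if scale <= 1" behaviour is confirmed by the description above. The name mscale_sketch is hypothetical.

```python
import math


def mscale_sketch(scale: float = 1, mscale: float = 1) -> float:
    # No magnitude correction is needed when the context is not extended.
    if scale <= 1:
        return 1.0
    # Assumed form: the correction grows with the log of the scaling factor.
    return 0.1 * mscale * math.log(scale) + 1.0


print(mscale_sketch(1.0))  # 1.0
```

Setting mscale to 1 would give the plain YaRN variant that the description contrasts with _yarn_get_mscale.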