easydel.layers.rotary_embedding
- easydel.layers.rotary_embedding.AVAILABLE_ROPE_TYPES
A dictionary that stores the registered RoPE (Rotary Position Embedding) types and their configurations. Each key is a type name. As rendered, each value is the attribute dictionary (methods and docstring) of the class registered under that name, and those docstrings match the class entries below. Registered types: 'deepseek_yarn' (DeepseekScalingRotaryEmbedding), 'default' (RotaryEmbedding), 'dynamic' (DynamicNTKScalingRotaryEmbedding), 'linear' (LinearScalingRotaryEmbedding), 'llama3' (Llama3RotaryEmbedding), 'longrope' (Phi3LongRoPEScaledRotaryEmbedding), 'mrope' (MultiModalRotaryEmbedding), 'yarn' (YaRNScalingRotaryEmbedding).
- class easydel.layers.rotary_embedding.DeepseekScalingRotaryEmbedding(*args: Any, **kwargs: Any)
Bases: Module
RotaryEmbedding implementing a YaRN-like scaling method (registered as 'deepseek_yarn'), likely derived from Deepseek models.
Uses the YaRN parameters (beta_fast, beta_slow, extrapolation_factor) plus the additional m-scale parameters mscale and mscale_all_dim. Its custom __call__ method differs slightly from apply_basic_rope.
- head_size
Dimension of each attention head.
- Type
int
- rotary_dim
Dimension subjected to rotary embedding.
- Type
int
- max_position_embeddings
Original maximum sequence length before scaling.
- Type
int
- base
Base for frequency calculation.
- Type
int
- is_neox_style
Use Neox rotation if True, GPT-J otherwise.
- Type
bool
- dtype
Data type for embeddings.
- Type
jnp.dtype
- scaling_factor
Primary scaling factor.
- Type
float
- extrapolation_factor
YaRN extrapolation factor.
- Type
float
- attn_factor
Attention scaling factor.
- Type
float
- beta_fast
YaRN parameter setting the correction range for high-frequency (fast-rotating) dimensions.
- Type
int
- beta_slow
YaRN parameter setting the correction range for low-frequency (slow-rotating) dimensions.
- Type
int
- mscale
Parameter for the m-scale calculation (passed to yarn_get_mscale).
- Type
float
- mscale_all_dim
Parameter for the m-scale calculation (passed to yarn_get_mscale).
- Type
float
- class easydel.layers.rotary_embedding.DynamicNTKScalingRotaryEmbedding(*args: Any, **kwargs: Any)
Bases: RotaryEmbedding
RotaryEmbedding extended with Dynamic NTK scaling.
Dynamically adjusts the base parameter based on the scaling factor.
- scaling_factor
The scaling factor applied to sequence length and base calculation.
- Type
float
- Inherits other attributes from RotaryEmbedding.
- class easydel.layers.rotary_embedding.LinearScalingRotaryEmbedding(*args: Any, **kwargs: Any)
Bases: RotaryEmbedding
RotaryEmbedding extended with Linear Scaling.
Linearly scales the position indices before calculating frequencies.
- scaling_factors
The factor(s) to scale positions by.
- Type
tp.Union[tp.List[float], float]
- Inherits other attributes from RotaryEmbedding.
- class easydel.layers.rotary_embedding.Llama3RotaryEmbedding(*args: Any, **kwargs: Any)
Bases: RotaryEmbedding
RotaryEmbedding implementing the Llama-3 scaling method.
Adjusts frequencies based on wavelength thresholds (low_freq_factor, high_freq_factor) and applies an overall scaling factor.
- scaling_factor
Overall scaling factor.
- Type
float
- low_freq_factor
Factor related to low frequency wavelength threshold.
- Type
float
- high_freq_factor
Factor related to high frequency wavelength threshold.
- Type
float
- orig_max_position
Original maximum sequence length before scaling.
- Type
int
- Inherits other attributes from RotaryEmbedding.
- class easydel.layers.rotary_embedding.MultiModalRotaryEmbedding(*args: Any, **kwargs: Any)
Bases: RotaryEmbedding
Multi-dimensional RoPE (MRoPE) with interleaved THW layout for Qwen2/3-VL models.
MRoPE (Multi-dimensional Rotary Position Embedding) extends standard RoPE to handle 3D position information (Temporal, Height, Width) for vision-language models.
The interleaving pattern reorganizes frequencies from chunked [TTT…HHH…WWW] to interleaved [T₀,H₀,W₀, T₁,H₁,W₁, …], preserving frequency continuity for each spatial/temporal dimension.
- mrope_section
Tuple of (T, H, W) sizes specifying how many frequency components are allocated to each dimension. Default: (24, 20, 20), which sums to 64 frequency components, i.e. half of a 128-dimensional head (head_dim 128 / 2).
- attention_scaling
Post-processing scaling factor applied to cos/sin. Default 1.0 for standard mRoPE. Can be set for advanced RoPE types.
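The interleaving can be illustrated with a small JAX sketch. It is modeled on the interleaved mRoPE in the Hugging Face Qwen3-VL code, which is an assumption here; EasyDeL's _apply_interleaved_mrope may differ in detail, and mrope_axis_selector and interleaved_mrope_angles are hypothetical names. Every frequency slot takes its rotation angle from the temporal position, except slots 1, 4, 7, ... (height) and 2, 5, 8, ... (width) inside the first 3 × section entries.

```python
import jax.numpy as jnp


def mrope_axis_selector(mrope_section):
    """Return, per frequency slot, which axis (0=T, 1=H, 2=W) supplies its angle."""
    n = sum(mrope_section)
    idx = jnp.arange(n)
    axis = jnp.zeros(n, dtype=jnp.int32)  # temporal by default
    # Height takes slots 1, 4, 7, ... below 3 * H-section; width takes slots 2, 5, 8, ... below 3 * W-section.
    axis = jnp.where((idx % 3 == 1) & (idx < 3 * mrope_section[1]), 1, axis)
    axis = jnp.where((idx % 3 == 2) & (idx < 3 * mrope_section[2]), 2, axis)
    return axis


def interleaved_mrope_angles(positions_thw, inv_freq, mrope_section=(24, 20, 20)):
    """positions_thw: [3, seq] T/H/W position ids; inv_freq: [rotary_dim // 2]."""
    assert sum(mrope_section) == inv_freq.shape[0]
    # Angles computed separately for each axis: [3, seq, rotary_dim // 2].
    angles = positions_thw[:, :, None].astype(jnp.float32) * inv_freq
    axis = mrope_axis_selector(mrope_section)
    # For every frequency slot keep the angle of its assigned axis: [seq, rotary_dim // 2].
    return jnp.where(axis == 0, angles[0], jnp.where(axis == 1, angles[1], angles[2]))
```

With the default (24, 20, 20) section, slots 0 to 59 cycle T, H, W and slots 60 to 63 stay temporal; cos and sin of these angles are then used as in standard RoPE.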
- class easydel.layers.rotary_embedding.Phi3LongRoPEScaledRotaryEmbedding(*args: Any, **kwargs: Any)
Bases: Module
RotaryEmbedding using the Phi-3 LongRoPE scaling method.
Applies different frequency scaling factors (short_factor, long_factor) depending on the target sequence length relative to the original maximum. Requires rotary_dim to be equal to head_size.
- head_size
Dimension of each attention head. Must equal rotary_dim.
- Type
int
- rotary_dim
Dimension subjected to rotary embedding. Must equal head_size.
- Type
int
- max_position_embeddings
The target maximum sequence length after scaling.
- Type
int
- original_max_position_embeddings
Original maximum sequence length before scaling.
- Type
int
- base
Base for frequency calculation.
- Type
int
- is_neox_style
Flag indicating whether Neox-style rotation is assumed (used by apply_phi3_rope).
- Type
bool
- dtype
Data type for computations.
- Type
jnp.dtype
- short_factor
Scaling factors applied when target length <= original max length.
- Type
tp.List[float]
- long_factor
Scaling factors applied when target length > original max length.
- Type
tp.List[float]
- class easydel.layers.rotary_embedding.RopeConfig(rope_type: str = 'default', factor: float | None = None, low_freq_factor: float | None = None, high_freq_factor: float | None = None, original_max_position_embeddings: int | None = None, long_factor: float | None = None, short_factor: float | None = None, long_mscale: float | None = None, short_mscale: float | None = None, beta_fast: int | None = None, beta_slow: int | None = None, mscale: int | None = None, mscale_all_dim: int | None = None, mrope_interleaved: bool | None = None, mrope_section: list[int] | None = None)
Bases: Mapping
Configuration class for RoPE (Rotary Position Embedding) parameters.
Stores the configuration related to RoPE type and its scaling parameters, making it easy to manage and pass around RoPE settings.
- rope_type
The type of RoPE scaling to use (e.g., “default”, “linear”, “yarn”, “llama3”). Defaults to “default”.
- Type
str
- factor
General scaling factor used by some types (linear, dynamic, yarn, llama3).
- Type
tp.Optional[float]
- low_freq_factor
Specific factor for Llama3 scaling.
- Type
tp.Optional[float]
- high_freq_factor
Specific factor for Llama3 scaling.
- Type
tp.Optional[float]
- original_max_position_embeddings
Original context window size, required by some scaling methods (yarn, llama3, phi3, deepseek).
- Type
tp.Optional[int]
- long_factor
Specific factor for Phi3 LongRoPE scaling (used for lengths > original).
- Type
tp.Optional[float]
- short_factor
Specific factor for Phi3 LongRoPE scaling (used for lengths <= original).
- Type
tp.Optional[float]
- long_mscale
Reserved for variants such as Phi-3; not used by the current get_rope.
- Type
tp.Optional[float]
- short_mscale
Reserved for variants such as Phi-3; not used by the current get_rope.
- Type
tp.Optional[float]
- extrapolation_factor
YaRN/Deepseek extrapolation factor.
- Type
tp.Optional[float]
- attn_factor
YaRN/Deepseek attention scaling factor.
- Type
tp.Optional[float]
- beta_fast
YaRN/Deepseek parameter for the high-frequency correction range.
- Type
tp.Optional[int]
- beta_slow
YaRN/Deepseek parameter for the low-frequency correction range.
- Type
tp.Optional[int]
- mscale
Deepseek m-scale parameter (passed to yarn_get_mscale).
- Type
tp.Optional[float]
- mscale_all_dim
Deepseek m-scale parameter (passed to yarn_get_mscale).
- Type
tp.Optional[float]
- classmethod from_dict(config_dict: dict[str, Any]) → RopeConfig
Create a RopeConfig instance from a dictionary.
Accepts ‘type’ as an alias for ‘rope_type’.
- Parameters
config_dict (tp.Dict[str, tp.Any]) – Dictionary containing RoPE configuration.
- Returns
An instance populated from the dictionary.
- Return type
RopeConfig
- from_tuple()
- items() → a set-like object providing a view on D's items
- keys() → a set-like object providing a view on D's keys
- replace(**kwargs)
- rope_type: str = 'default'
- to_dict() → dict[str, Any]
Convert the RopeConfig instance to a dictionary.
Attributes whose value is None are omitted. The result is wrapped in a custom hashable dictionary class so it can be used in JIT compilation contexts (although marking the dictionary as static in get_frequencies is the preferred approach).
- Returns
A hashable dictionary containing non-None configuration values.
- Return type
tp.Dict[str, tp.Any]
- to_tuple()
- values() → an object providing a view on D's values
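The alias handling in from_dict and the None-filtering in to_dict can be illustrated with a small stand-in. MiniRopeConfig is hypothetical and covers only three fields; how the real RopeConfig treats unrecognized keys is not documented here, and this sketch simply ignores them.

```python
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class MiniRopeConfig:
    """Hypothetical, reduced stand-in for RopeConfig (three fields only)."""

    rope_type: str = "default"
    factor: Optional[float] = None
    original_max_position_embeddings: Optional[int] = None

    @classmethod
    def from_dict(cls, config_dict: dict) -> "MiniRopeConfig":
        data = dict(config_dict)
        # Treat the legacy "type" key as an alias for "rope_type".
        if "type" in data and "rope_type" not in data:
            data["rope_type"] = data.pop("type")
        known = {f.name for f in fields(cls)}
        # This sketch drops keys it does not recognize.
        return cls(**{k: v for k, v in data.items() if k in known})

    def to_dict(self) -> dict:
        # Omit fields that are still None.
        return {f.name: getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None}
```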
- class easydel.layers.rotary_embedding.RotaryEmbedding(*args: Any, **kwargs: Any)
Bases: Module
Standard Rotary Positional Embedding (RoPE) module.
- head_size
The dimension size of each attention head.
- Type
int
- rotary_dim
The dimension size of the rotary embeddings applied. Can be <= head_size.
- Type
int
- max_position_embeddings
The maximum sequence length the model can handle.
- Type
int
- base
The base value for calculating frequencies.
- Type
int
- is_neox_style
Flag indicating whether to use Neox-style rotation.
- Type
bool
- dtype
Data type for computations.
- Type
jnp.dtype
- class easydel.layers.rotary_embedding.YaRNScalingRotaryEmbedding(*args: Any, **kwargs: Any)
Bases: RotaryEmbedding
RotaryEmbedding extended with the YaRN (Yet another RoPE extensioN method) scaling.
Combines interpolation and extrapolation with frequency correction and magnitude scaling.
- scaling_factor
The primary scaling factor for context length.
- Type
tp.Union[float, int]
- extrapolation_factor
Controls the strength of extrapolation correction.
- Type
float
- attn_factor
Scales the output attention values.
- Type
float
- beta_fast
YaRN parameter for high-frequency dimensions correction range.
- Type
int
- beta_slow
YaRN parameter for low-frequency dimensions correction range.
- Type
int
- Inherits other attributes from RotaryEmbedding. Note: max_position_embeddings in the parent __init__ likely refers to the original (pre-scaling) maximum length used in the YaRN calculations.
- easydel.layers.rotary_embedding.apply_basic_rope(query: jax.Array, key: jax.Array, positions: jax.Array, frequencies: jax.Array, rotary_dim: int, is_neox_style: bool, offsets: jax.Array = None, dtype: numpy.dtype = jax.numpy.float32)
Applies standard or partially applied RoPE to query and key tensors.
Selects frequencies based on positions (and optional offsets), then applies the rotation using _apply_rotary_emb. Handles cases where RoPE is applied only to a subset of the head dimension (rotary_dim < query.shape[-1]).
- Parameters
query (jax.Array) – Query tensor. Shape […, sequence_length, num_heads, head_dim].
key (jax.Array) – Key tensor. Shape […, sequence_length, num_heads, head_dim].
positions (jax.Array) – Array of positions for lookup in the frequency cache. Shape [sequence_length].
frequencies (jax.Array) – Precomputed frequency cache. Shape [max_length, rotary_dim_freq].
rotary_dim (int) – The dimension up to which RoPE is applied.
is_neox_style (bool) – Whether to use Neox-style rotation.
offsets (jax.Array, optional) – Optional offsets to add to positions. Defaults to None.
dtype (jnp.dtype, optional) – Output dtype. Defaults to jnp.float32.
- Returns
The rotated query and key tensors with the specified dtype.
- Return type
tuple[jax.Array, jax.Array]
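The two rotation styles selected by is_neox_style, plus the partial case rotary_dim < head_dim, can be sketched in JAX as below. This follows the common Neox/GPT-J formulation found in vLLM-style code and assumes the concat(cos, sin) cache layout described for compute_basic_frequencies; it is not EasyDeL's _apply_rotary_emb, rotate_pairs and apply_rope_sketch are hypothetical names, and the batch axis is omitted for brevity.

```python
import jax.numpy as jnp


def rotate_pairs(x, cos, sin, is_neox_style):
    """Rotate (x1, x2) pairs by the angles encoded in cos/sin."""
    if is_neox_style:
        # Neox style: pair dimension i with dimension i + rotary_dim // 2.
        x1, x2 = jnp.split(x, 2, axis=-1)
        return jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    # GPT-J style: pair adjacent (even, odd) dimensions.
    x1, x2 = x[..., ::2], x[..., 1::2]
    out = jnp.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    return out.reshape(x.shape)  # re-interleave the rotated pairs


def apply_rope_sketch(x, positions, cache, rotary_dim, is_neox_style=True):
    """x: [seq, heads, head_dim]; cache: [max_len, rotary_dim] laid out as concat(cos, sin)."""
    cos, sin = jnp.split(cache[positions], 2, axis=-1)  # each [seq, rotary_dim // 2]
    cos, sin = cos[:, None, :], sin[:, None, :]  # broadcast over the heads axis
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    # Only the first rotary_dim dimensions are rotated; the rest pass through unchanged.
    return jnp.concatenate([rotate_pairs(x_rot, cos, sin, is_neox_style), x_pass], axis=-1)
```

Both styles rotate each pair by the same angle, so they preserve vector norms and reduce to the identity at position 0. They differ only in which dimensions are paired, which is why weights trained with one style must be used with the same is_neox_style.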
- easydel.layers.rotary_embedding.apply_phi3_rope(query, key, positions, frequencies, offsets: jax.Array = None, dtype: numpy.dtype = jax.numpy.float32)
Applies Phi-3 LongRoPE to query and key tensors.
Rotates with the Neox-style helper (_rotate_neox), the layout Phi-3 assumes.
- Parameters
query (jax.Array) – Query tensor. Shape [batch_size, sequence_length, num_heads, head_dim].
key (jax.Array) – Key tensor. Shape [batch_size, sequence_length, num_heads, head_dim].
positions (jax.Array) – Array of positions. Shape [sequence_length].
frequencies (jax.Array) – Precomputed Phi-3 frequency cache. Shape [1, max_length, rotary_dim].
offsets (jax.Array, optional) – Optional offsets to add to positions. Defaults to None.
dtype (jnp.dtype, optional) – Output dtype. Defaults to jnp.float32.
- Returns
The rotated query and key tensors with the specified dtype.
- Return type
tuple[jax.Array, jax.Array]
- easydel.layers.rotary_embedding.compute_basic_frequencies(base: int, rotary_dim: int, max_position_embeddings: int)
Computes the basic RoPE frequencies (cos and sin values) for all positions.
- Parameters
base (int) – The base value for the geometric progression of frequencies.
rotary_dim (int) – The dimension of the rotary embeddings.
max_position_embeddings (int) – The maximum sequence length.
- Returns
A frequency cache tensor of shape (max_position_embeddings, rotary_dim): the cos values concatenated with the sin values along the last axis.
- Return type
jnp.ndarray
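A minimal JAX sketch of the basic cache, assuming the standard RoPE inverse frequencies base^(-2i/rotary_dim) for i = 0, ..., rotary_dim/2 - 1 (consistent with the (rotary_dim // 2,) shape documented for compute_basic_inv_frequencies). basic_inv_frequencies and basic_frequency_cache are hypothetical names, not EasyDeL functions.

```python
import jax.numpy as jnp


def basic_inv_frequencies(base: float, rotary_dim: int) -> jnp.ndarray:
    # base^(-2i / rotary_dim) for i = 0 .. rotary_dim // 2 - 1
    exponents = jnp.arange(0, rotary_dim, 2, dtype=jnp.float32) / rotary_dim
    return 1.0 / (base ** exponents)


def basic_frequency_cache(base: float, rotary_dim: int, max_positions: int) -> jnp.ndarray:
    inv_freq = basic_inv_frequencies(base, rotary_dim)
    t = jnp.arange(max_positions, dtype=jnp.float32)
    # Outer product: angles[p, i] = p * inv_freq[i]
    angles = jnp.outer(t, inv_freq)
    # cos values first, sin values second, along the last axis -> [max_positions, rotary_dim]
    return jnp.concatenate([jnp.cos(angles), jnp.sin(angles)], axis=-1)
```

Row p of the cache holds every angle needed to rotate a token at position p, so lookup at apply time is a plain gather.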
- easydel.layers.rotary_embedding.compute_basic_inv_frequencies(base: int, rotary_dim: int)
Computes the inverse frequencies for standard RoPE.
- Parameters
base (int) – The base value for the geometric progression of frequencies.
rotary_dim (int) – The dimension of the rotary embeddings.
- Returns
An array of inverse frequencies of shape (rotary_dim // 2,).
- Return type
jnp.ndarray
- easydel.layers.rotary_embedding.compute_deepseek_frequencies(base, rotary_dim, scaling_factor, extrapolation_factor, beta_fast, beta_slow, max_position_embeddings, mscale, mscale_all_dim, attn_factor) → Array
Computes RoPE frequencies using the Deepseek-YaRN scaling method.
Similar to YaRN, but the m-scale is derived from the additional mscale and mscale_all_dim parameters.
- Parameters
base (float) – The base value for positional encoding.
rotary_dim (int) – The dimension of the rotary embeddings.
scaling_factor (float) – The factor by which the context length is scaled.
extrapolation_factor (float) – YaRN parameter controlling extrapolation strength.
beta_fast (int) – YaRN parameter for faster rotating dimensions.
beta_slow (int) – YaRN parameter for slower rotating dimensions.
max_position_embeddings (int) – Original maximum sequence length before scaling.
mscale (float) – Parameter for yarn_get_mscale calculation.
mscale_all_dim (float) – Parameter for yarn_get_mscale calculation.
attn_factor (float) – Scaling factor applied to attention outputs.
- Returns
A frequency cache tensor of shape (max_position_embeddings * scaling_factor, rotary_dim).
- Return type
jnp.ndarray
- easydel.layers.rotary_embedding.compute_dynamic_frequencies(base: int, rotary_dim: int, max_position_embeddings: int, scaling_factor: float)
Computes RoPE frequencies using Dynamic NTK scaling.
Adjusts the ‘base’ dynamically based on the scaling factor.
- Parameters
base (int) – The initial base value before dynamic adjustment.
rotary_dim (int) – The dimension of the rotary embeddings.
max_position_embeddings (int) – The base maximum sequence length before scaling.
scaling_factor (float) – The scaling factor applied to the sequence length.
- Returns
A frequency cache tensor of shape (max_position_embeddings * scaling_factor, rotary_dim).
- Return type
jnp.ndarray
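The base adjustment can be sketched as follows. The formula, base × ((s · L_target / L) - (s - 1))^(d / (d - 2)) with L_target = s · L, is the one used by vLLM's dynamic NTK RoPE; that compute_dynamic_frequencies uses exactly this formula is an assumption, and both function names are hypothetical.

```python
import jax.numpy as jnp


def dynamic_ntk_base(base, rotary_dim, max_position_embeddings, scaling_factor):
    # Cache length after scaling the context window.
    max_len = max_position_embeddings * scaling_factor
    # NTK-aware base adjustment (vLLM-style formula, assumed here).
    return base * (
        (scaling_factor * max_len / max_position_embeddings) - (scaling_factor - 1)
    ) ** (rotary_dim / (rotary_dim - 2))


def dynamic_ntk_frequency_cache(base, rotary_dim, max_position_embeddings, scaling_factor):
    new_base = dynamic_ntk_base(base, rotary_dim, max_position_embeddings, scaling_factor)
    inv_freq = 1.0 / (new_base ** (jnp.arange(0, rotary_dim, 2, dtype=jnp.float32) / rotary_dim))
    t = jnp.arange(int(max_position_embeddings * scaling_factor), dtype=jnp.float32)
    angles = jnp.outer(t, inv_freq)
    # Same concat(cos, sin) layout as the basic cache, just with a larger base.
    return jnp.concatenate([jnp.cos(angles), jnp.sin(angles)], axis=-1)
```

With scaling_factor = 1 the base is unchanged; larger factors raise the base, which stretches the low-frequency wavelengths while leaving the first (highest-frequency) component at 1.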
- easydel.layers.rotary_embedding.compute_linear_frequencies(base: int, rotary_dim: int, max_position_embeddings: int, scaling_factors: list[float])
Computes RoPE frequencies using linear scaling for potentially multiple factors.
Computes one frequency cache per scaling factor and concatenates them along the position axis. This layout appears intended for cases where different parts of a sequence use different scaling factors, selected via offsets. With a single scaling factor it reduces to standard linear scaling.
- Parameters
base (int) – The base value for the geometric progression of frequencies.
rotary_dim (int) – The dimension of the rotary embeddings.
max_position_embeddings (int) – The base maximum sequence length before scaling.
scaling_factors (tp.Union[tp.List[float], float]) – A single scaling factor or a list of scaling factors.
- Returns
A frequency cache tensor. If multiple scaling factors are provided, the caches are concatenated along the position dimension. Shape is (total_scaled_length, rotary_dim).
- Return type
jnp.ndarray
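Linear scaling amounts to dividing positions by the factor before computing the angles. The sketch below assumes that formulation and shows the multi-factor concatenation along the position axis; both function names are hypothetical.

```python
import jax.numpy as jnp


def linear_scaled_frequency_cache(base, rotary_dim, max_position_embeddings, scaling_factor):
    inv_freq = 1.0 / (base ** (jnp.arange(0, rotary_dim, 2, dtype=jnp.float32) / rotary_dim))
    max_len = int(max_position_embeddings * scaling_factor)
    # Position interpolation: compress positions by the scaling factor.
    t = jnp.arange(max_len, dtype=jnp.float32) / scaling_factor
    angles = jnp.outer(t, inv_freq)
    return jnp.concatenate([jnp.cos(angles), jnp.sin(angles)], axis=-1)


def multi_factor_frequency_cache(base, rotary_dim, max_position_embeddings, scaling_factors):
    # One cache per factor, stacked along the position axis.
    caches = [
        linear_scaled_frequency_cache(base, rotary_dim, max_position_embeddings, f)
        for f in scaling_factors
    ]
    return jnp.concatenate(caches, axis=0)
```

With factor 2.0, row 2 of the scaled cache equals row 1 of the unscaled one, which is exactly the interpolation effect.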
- easydel.layers.rotary_embedding.compute_llama3_frequencies(base, rotary_dim, low_freq_factor, high_freq_factor, scaling_factor, max_position_embeddings: int)
Computes RoPE frequencies using the Llama3 scaling method.
- Parameters
base (float) – The base value for positional encoding.
rotary_dim (int) – The dimension of the rotary embeddings.
low_freq_factor (float) – Factor for adjusting low-frequency components.
high_freq_factor (float) – Factor for adjusting high-frequency components.
scaling_factor (float) – The overall scaling factor applied.
max_position_embeddings (int) – Original maximum sequence length (referred to as orig_max_position in compute_llama3_inv_frequencies). This defines the length of the frequency cache.
- Returns
A frequency cache tensor of shape (max_position_embeddings, rotary_dim).
- Return type
jnp.ndarray
- easydel.layers.rotary_embedding.compute_llama3_inv_frequencies(base, rotary_dim, low_freq_factor, high_freq_factor, orig_max_position, scaling_factor)
Computes the inverse frequencies for Llama3-style scaled RoPE.
Adjusts frequencies based on wavelength thresholds and a smoothing factor.
- Parameters
base (float) – The base value for positional encoding.
rotary_dim (int) – The dimension of the rotary embeddings.
low_freq_factor (float) – Factor for adjusting low-frequency components.
high_freq_factor (float) – Factor for adjusting high-frequency components.
orig_max_position (int) – Original maximum sequence length before scaling.
scaling_factor (float) – The overall scaling factor applied.
- Returns
An array of Llama3-adjusted inverse frequencies of shape (rotary_dim // 2,).
- Return type
jnp.ndarray
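The wavelength-threshold logic can be sketched in JAX using the scaling rule published with Llama 3.1: wavelengths shorter than orig_max_position / high_freq_factor keep their frequency, wavelengths longer than orig_max_position / low_freq_factor are divided by scaling_factor, and the band in between is blended linearly. That compute_llama3_inv_frequencies matches this exactly is an assumption; llama3_inv_frequencies is a hypothetical name.

```python
import math

import jax.numpy as jnp


def llama3_inv_frequencies(base, rotary_dim, low_freq_factor, high_freq_factor,
                           orig_max_position, scaling_factor):
    inv_freq = 1.0 / (base ** (jnp.arange(0, rotary_dim, 2, dtype=jnp.float32) / rotary_dim))
    low_freq_wavelen = orig_max_position / low_freq_factor    # longer wavelengths: fully scaled
    high_freq_wavelen = orig_max_position / high_freq_factor  # shorter wavelengths: untouched
    wavelen = 2 * math.pi / inv_freq
    # Blend weight for the middle band: 0 at the low-frequency edge, 1 at the high-frequency edge.
    smooth = (orig_max_position / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    blended = (1 - smooth) * inv_freq / scaling_factor + smooth * inv_freq
    return jnp.where(
        wavelen < high_freq_wavelen,
        inv_freq,
        jnp.where(wavelen > low_freq_wavelen, inv_freq / scaling_factor, blended),
    )
```

For example, with base 500000, rotary_dim 128, low/high factors 1 and 4, an original length of 8192 and factor 8, the first component is unchanged and the last one is divided by 8.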
- easydel.layers.rotary_embedding.compute_phi3_frequencies(base, head_size, rotary_dim, max_position_embeddings, original_max_position_embeddings, short_factor, long_factor)
Computes RoPE frequencies using the Phi-3 LongRoPE scaling method.
Applies different scaling factors based on whether the target length is shorter or longer than the original max length. Includes a scaling factor adjustment based on the ratio of target length to original length.
- Parameters
base (float) – The base value for positional encoding.
head_size (int) – The dimension of each attention head.
rotary_dim (int) – The dimension of the rotary embeddings. Must equal head_size for Phi-3.
max_position_embeddings (int) – The target maximum sequence length after scaling.
original_max_position_embeddings (int) – Original maximum sequence length before scaling.
short_factor (tp.List[float]) – Scaling factors for frequencies when max_position_embeddings <= original_max_position_embeddings.
long_factor (tp.List[float]) – Scaling factors for frequencies when max_position_embeddings > original_max_position_embeddings.
- Returns
A frequency cache tensor of shape (1, max_position_embeddings, rotary_dim).
- Return type
jnp.ndarray
- Raises
ValueError – If rotary_dim does not equal head_size.
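A JAX sketch of the LongRoPE cache, assuming the formulation of the Hugging Face Phi-3 implementation: the per-frequency factors divide the inverse frequencies, and for target lengths beyond the original one the cos/sin values are multiplied by sqrt(1 + ln(scale) / ln(original_max_position_embeddings)). phi3_longrope_cache is hypothetical; head_size is omitted because the sketch already assumes rotary_dim == head_size.

```python
import math

import jax.numpy as jnp


def phi3_longrope_cache(base, rotary_dim, max_position_embeddings,
                        original_max_position_embeddings, short_factor, long_factor):
    # Long factors only when the target context exceeds the original one.
    use_long = max_position_embeddings > original_max_position_embeddings
    ext_factors = jnp.asarray(long_factor if use_long else short_factor, dtype=jnp.float32)
    exponents = jnp.arange(0, rotary_dim, 2, dtype=jnp.float32) / rotary_dim
    # Per-frequency rescaling: one factor per inverse frequency.
    inv_freq = 1.0 / (ext_factors * base ** exponents)
    # Magnitude correction from the ratio of target to original length.
    scale = max_position_embeddings / original_max_position_embeddings
    mag = 1.0 if scale <= 1.0 else math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))
    t = jnp.arange(max_position_embeddings, dtype=jnp.float32)
    angles = jnp.outer(t, inv_freq)
    cache = jnp.concatenate([jnp.cos(angles) * mag, jnp.sin(angles) * mag], axis=-1)
    return cache[None]  # leading axis of size 1: (1, max_position_embeddings, rotary_dim)
```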
- easydel.layers.rotary_embedding.compute_yarn_frequencies(base: float, rotary_dim: int, beta_fast: float, beta_slow: float, max_position_embeddings: int, scaling_factor: float, extrapolation_factor: float, attn_factor: float) → Array
Computes RoPE frequencies using the YaRN scaling method.
Includes adjustments based on YaRN parameters and applies an mscale factor.
- Parameters
base (float) – The base value for positional encoding.
rotary_dim (int) – The dimension of the rotary embeddings.
beta_fast (float) – YaRN parameter for faster rotating dimensions.
beta_slow (float) – YaRN parameter for slower rotating dimensions.
max_position_embeddings (int) – Original maximum sequence length before scaling.
scaling_factor (float) – The factor by which the context length is scaled.
extrapolation_factor (float) – YaRN parameter controlling extrapolation strength.
attn_factor (float) – YaRN parameter scaling the attention outputs.
- Returns
A frequency cache tensor of shape (max_position_embeddings * scaling_factor, rotary_dim).
- Return type
jnp.ndarray
- easydel.layers.rotary_embedding.compute_yarn_inv_frequencies(base: float, rotary_dim: int, beta_fast: float, beta_slow: float, max_position_embeddings: int, scaling_factor: float, extrapolation_factor: float) → Array
Computes the inverse frequencies for YaRN scaled RoPE.
Combines interpolation and extrapolation frequencies based on correction ranges.
- Parameters
base (float) – The base value for positional encoding.
rotary_dim (int) – The dimension of the rotary embeddings.
beta_fast (float) – YaRN parameter for faster rotating dimensions.
beta_slow (float) – YaRN parameter for slower rotating dimensions.
max_position_embeddings (int) – Original maximum sequence length before scaling.
scaling_factor (float) – The factor by which the context length is scaled.
extrapolation_factor (float) – YaRN parameter controlling extrapolation strength.
- Returns
An array of YaRN-adjusted inverse frequencies of shape (rotary_dim // 2,).
- Return type
jnp.ndarray
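The YaRN blend can be sketched as below, following the reference YaRN formulation as implemented in vLLM (an assumption for EasyDeL). beta_fast and beta_slow are turned into a dimension range [low, high]; dimensions below low keep their original (extrapolated) frequency, dimensions above high are fully interpolated by scaling_factor, and a linear ramp blends the two in between. All helper names are hypothetical.

```python
import math

import jax.numpy as jnp


def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
    # Dimension index whose wavelength makes num_rotations full turns over the original context.
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))


def find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings):
    low = math.floor(find_correction_dim(beta_fast, dim, base, max_position_embeddings))
    high = math.ceil(find_correction_dim(beta_slow, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim - 1)


def linear_ramp(low, high, n):
    if low == high:
        high += 0.001  # avoid division by zero
    ramp = (jnp.arange(n, dtype=jnp.float32) - low) / (high - low)
    return jnp.clip(ramp, 0.0, 1.0)


def yarn_inv_frequencies(base, rotary_dim, beta_fast, beta_slow,
                         max_position_embeddings, scaling_factor, extrapolation_factor):
    pos_freqs = base ** (jnp.arange(0, rotary_dim, 2, dtype=jnp.float32) / rotary_dim)
    inv_freq_extrapolation = 1.0 / pos_freqs                     # original frequencies
    inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)  # position-interpolated frequencies
    low, high = find_correction_range(beta_fast, beta_slow, rotary_dim, base, max_position_embeddings)
    # mask = 1 keeps the extrapolated frequency, mask = 0 uses the interpolated one.
    mask = (1.0 - linear_ramp(low, high, rotary_dim // 2)) * extrapolation_factor
    return inv_freq_interpolation * (1.0 - mask) + inv_freq_extrapolation * mask
```

For base 10000, rotary_dim 128, an original length of 4096, beta_fast 32 and beta_slow 1, the correction range works out to dimensions 20 to 46.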
- easydel.layers.rotary_embedding.get_frequencies(head_size: int, rotary_dim: int, max_position: int, base: int, rope_scaling: dict[str, Any] | None = None, partial_rotary_factor: float = 1.0) → Array
Computes and returns the RoPE frequency cache based on configuration.
Selects the appropriate frequency computation function (basic, linear, dynamic, YaRN, Llama3, Phi3, Deepseek) based on the rope_scaling dictionary. This function is JIT-compiled for performance, with relevant parameters marked static.
- Parameters
head_size (int) – Dimension of each attention head (needed for some scaling types like Phi3).
rotary_dim (int) – Base dimension for rotary embedding (before partial factor).
max_position (int) – Maximum sequence length for which to compute frequencies. This might be the original or target length depending on scaling type.
base (int) – Base value for frequency calculation.
rope_scaling (tp.Optional[tp.Dict[str, tp.Any]], optional) – Dictionary specifying the type and parameters of RoPE scaling. Determines which frequency function to call. Defaults to None (uses compute_basic_frequencies).
partial_rotary_factor (float, optional) – Factor to reduce the rotary dimension. Defaults to 1.0.
- Returns
The computed frequency cache tensor. Shape depends on the scaling method, typically [computed_length, rotary_dim_effective].
- Return type
jax.Array
- Raises
ValueError – If rope_scaling specifies an unknown rope_type.
- easydel.layers.rotary_embedding.get_inv_frequencies(head_size: int, rotary_dim: int, max_position: int, base: int, rope_scaling: dict[str, Any] | None = None, partial_rotary_factor: float = 1.0) → Array
Computes and returns just the inverse frequencies for RoPE based on configuration.
Similar to get_frequencies but returns only the inverse frequencies without computing the full frequency cache (no cos/sin transformation).
- Parameters
head_size (int) – Dimension of each attention head (needed for some scaling types like Phi3).
rotary_dim (int) – Base dimension for rotary embedding (before partial factor).
max_position (int) – Maximum sequence length the model should support.
base (int) – Base value for frequency calculation.
rope_scaling (tp.Optional[tp.Dict[str, tp.Any]], optional) – Dictionary specifying the type and parameters of RoPE scaling. Determines which frequency function to call. Defaults to None (uses basic inverse frequencies).
partial_rotary_factor (float, optional) – Factor to reduce the rotary dimension. Defaults to 1.0.
- Returns
The computed inverse frequencies. Shape is typically (rotary_dim // 2,).
- Return type
jax.Array
- Raises
ValueError – If rope_scaling specifies an unknown rope_type.
- easydel.layers.rotary_embedding.get_rope(head_size: int, rotary_dim: int, max_position: int, base: int, is_neox_style: bool = True, rope_scaling: dict[str, Any] | None = None, dtype: numpy.dtype | None = None, partial_rotary_factor: float = 1.0) → RotaryEmbedding
Factory function to create and return a RotaryEmbedding instance based on configuration.
Selects the appropriate RoPE class (standard, linear, dynamic, YaRN, Llama3, Phi3, Deepseek) based on the rope_scaling dictionary.
- Parameters
head_size (int) – Dimension of each attention head.
rotary_dim (int) – Base dimension for rotary embedding (before partial factor).
max_position (int) – Maximum sequence length the model should support (target length).
base (int) – Base value for frequency calculation.
is_neox_style (bool, optional) – Use Neox rotation style. Defaults to True.
rope_scaling (tp.Optional[tp.Dict[str, tp.Any]], optional) – Dictionary specifying the type and parameters of RoPE scaling. If None or ‘rope_type’ is ‘default’, uses standard RoPE. Keys like ‘rope_type’, ‘factor’, ‘original_max_position_embeddings’, etc., are used. Defaults to None.
dtype (tp.Optional[jnp.dtype], optional) – Data type for embeddings. Defaults to None, in which case jnp.float32 is used.
partial_rotary_factor (float, optional) – Factor to reduce the rotary dimension (e.g., 0.5 applies RoPE to half the dimensions). Defaults to 1.0.
- Returns
An instance of the configured RotaryEmbedding subclass.
- Return type
RotaryEmbedding
- Raises
ValueError – If rope_scaling specifies an unknown rope_type.
- easydel.layers.rotary_embedding.rope_wraper(type)
A decorator factory that registers a RotaryEmbedding class under a specific type name.
This allows retrieving RoPE configurations by type name later. It also sets basic __str__ and __repr__ for the decorated class.
- Parameters
type (str) – The name to register the RoPE class under (e.g., “linear”, “yarn”).
- Returns
A decorator function that takes a RotaryEmbedding class, registers it, and returns the class.
- Return type
Callable
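A decorator of this kind can be sketched in plain Python. ROPE_REGISTRY and register_rope are hypothetical stand-ins for AVAILABLE_ROPE_TYPES and rope_wraper; this sketch stores the class itself, and the exact string format given to __str__/__repr__ is an assumption.

```python
from typing import Callable, Dict

# Hypothetical stand-in for AVAILABLE_ROPE_TYPES.
ROPE_REGISTRY: Dict[str, type] = {}


def register_rope(type_name: str) -> Callable[[type], type]:
    """Decorator factory: register a class under type_name and give it a readable str/repr."""

    def decorator(cls: type) -> type:
        ROPE_REGISTRY[type_name] = cls

        def _describe(self) -> str:
            return f"{type(self).__name__}(type={type_name!r})"

        cls.__str__ = _describe
        cls.__repr__ = _describe
        return cls

    return decorator


@register_rope("linear")
class MyLinearRope:
    pass
```

Looking up ROPE_REGISTRY["linear"] then returns the class, which is how a factory such as get_rope can map a rope_type string to an implementation.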
- easydel.layers.rotary_embedding.yarn_get_mscale(scale: float = 1, mscale: float = 1) → float
Calculates the mscale factor, potentially used by Deepseek-YaRN or similar methods.
Allows specifying an additional mscale parameter compared to _yarn_get_mscale.
- Parameters
scale (float, optional) – The scaling factor. Defaults to 1.
mscale (float, optional) – An additional scaling parameter. Defaults to 1.
- Returns
The calculated mscale value. Returns 1.0 if scale <= 1.
- Return type
float
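A minimal sketch of the m-scale arithmetic. It assumes the formula from DeepSeek-V2's reference code (0.1 · mscale · ln(scale) + 1 for scale > 1, else 1.0) and, for the Deepseek variant, the ratio of two such values times attn_factor as in vLLM's DeepseekScalingRotaryEmbedding. This page confirms only the "returns 1.0 if scale <= 1" behavior; yarn_mscale and deepseek_mscale are hypothetical names.

```python
import math


def yarn_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    """Hypothetical restatement of the m-scale formula."""
    if scale <= 1:
        return 1.0  # no correction when the context is not extended
    return 0.1 * mscale * math.log(scale) + 1.0


def deepseek_mscale(scaling_factor, mscale, mscale_all_dim, attn_factor):
    """Ratio of two m-scales times attn_factor (vLLM-style Deepseek YaRN; assumed)."""
    return yarn_mscale(scaling_factor, mscale) / yarn_mscale(scaling_factor, mscale_all_dim) * attn_factor
```

When mscale equals mscale_all_dim the ratio is 1, so only attn_factor remains.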