easydel.layers.rotary_embedding


easydel.layers.rotary_embedding.AVAILABLE_ROPE_TYPES#

A dictionary to store registered RoPE (Rotary Position Embedding) types and their configurations.

Each key maps to the attributes of the registered class (__call__, __init__, __doc__, __module__, and so on). Registered keys:

  • 'deepseek_yarn' – DeepseekScalingRotaryEmbedding

  • 'default' – RotaryEmbedding

  • 'dynamic' – DynamicNTKScalingRotaryEmbedding

  • 'linear' – LinearScalingRotaryEmbedding

  • 'llama3' – Llama3RotaryEmbedding

  • 'longrope' – Phi3LongRoPEScaledRotaryEmbedding

  • 'yarn' – YaRNScalingRotaryEmbedding

class easydel.layers.rotary_embedding.DeepseekScalingRotaryEmbedding(*args: Any, **kwargs: Any)[source]#

Bases: Module

RotaryEmbedding implementing a YaRN-like scaling method, potentially from Deepseek models.

Uses YaRN parameters (beta_fast, beta_slow, extrapolation_factor) and includes additional m-scale parameters (mscale, mscale_all_dim). This version has a custom __call__ method differing slightly from apply_basic_rope.

head_size#

Dimension of each attention head.

Type

int

rotary_dim#

Dimension subjected to rotary embedding.

Type

int

max_position_embeddings#

Original maximum sequence length before scaling.

Type

int

base#

Base for frequency calculation.

Type

int

is_neox_style#

Use Neox rotation if True, GPT-J otherwise.

Type

bool

dtype#

Data type for embeddings.

Type

jnp.dtype

scaling_factor#

Primary scaling factor.

Type

float

extrapolation_factor#

YaRN extrapolation factor.

Type

float

attn_factor#

Attention scaling factor.

Type

float

beta_fast#

YaRN parameter.

Type

int

beta_slow#

YaRN parameter.

Type

int

mscale#

Parameter for m-scale calculation.

Type

float

mscale_all_dim#

Parameter for m-scale calculation.

Type

float

class easydel.layers.rotary_embedding.DynamicNTKScalingRotaryEmbedding(*args: Any, **kwargs: Any)[source]#

Bases: RotaryEmbedding

RotaryEmbedding extended with Dynamic NTK scaling.

Dynamically adjusts the base parameter based on the scaling factor.

scaling_factor#

The scaling factor applied to sequence length and base calculation.

Type

float

Inherits other attributes from RotaryEmbedding.

class easydel.layers.rotary_embedding.LinearScalingRotaryEmbedding(*args: Any, **kwargs: Any)[source]#

Bases: RotaryEmbedding

RotaryEmbedding extended with Linear Scaling.

Linearly scales the position indices before calculating frequencies.

scaling_factors#

The factor(s) to scale positions by.

Type

tp.Union[tp.List[float], float]

Inherits other attributes from RotaryEmbedding.

class easydel.layers.rotary_embedding.Llama3RotaryEmbedding(*args: Any, **kwargs: Any)[source]#

Bases: RotaryEmbedding

RotaryEmbedding implementing the Llama-3 scaling method.

Adjusts frequencies based on wavelength thresholds (low_freq_factor, high_freq_factor) and applies an overall scaling factor.

scaling_factor#

Overall scaling factor.

Type

float

low_freq_factor#

Factor related to low frequency wavelength threshold.

Type

float

high_freq_factor#

Factor related to high frequency wavelength threshold.

Type

float

orig_max_position#

Original maximum sequence length before scaling.

Type

int

Inherits other attributes from RotaryEmbedding.

class easydel.layers.rotary_embedding.Phi3LongRoPEScaledRotaryEmbedding(*args: Any, **kwargs: Any)[source]#

Bases: Module

RotaryEmbedding using the Phi-3 LongRoPE scaling method.

Applies different frequency scaling factors (short_factor, long_factor) depending on the target sequence length relative to the original maximum. Requires rotary_dim to be equal to head_size.

head_size#

Dimension of each attention head. Must equal rotary_dim.

Type

int

rotary_dim#

Dimension subjected to rotary embedding. Must equal head_size.

Type

int

max_position_embeddings#

The target maximum sequence length after scaling.

Type

int

original_max_position_embeddings#

Original maximum sequence length before scaling.

Type

int

base#

Base for frequency calculation.

Type

int

is_neox_style#

Flag indicating whether Neox-style rotation is assumed (used by apply_phi3_rope).

Type

bool

dtype#

Data type for computations.

Type

jnp.dtype

short_factor#

Scaling factors applied when target length <= original max length.

Type

tp.List[float]

long_factor#

Scaling factors applied when target length > original max length.

Type

tp.List[float]

class easydel.layers.rotary_embedding.RopeConfig(rope_type: str = 'default', factor: Optional[float] = None, low_freq_factor: Optional[float] = None, high_freq_factor: Optional[float] = None, original_max_position_embeddings: Optional[int] = None, long_factor: Optional[float] = None, short_factor: Optional[float] = None, long_mscale: Optional[float] = None, short_mscale: Optional[float] = None)[source]#

Bases: Mapping

Configuration class for RoPE (Rotary Position Embedding) parameters.

Stores the configuration related to RoPE type and its scaling parameters, making it easy to manage and pass around RoPE settings.

rope_type#

The type of RoPE scaling to use (e.g., “default”, “linear”, “yarn”, “llama3”). Defaults to “default”.

Type

str

factor#

General scaling factor used by some types (linear, dynamic, yarn, llama3).

Type

tp.Optional[float]

low_freq_factor#

Specific factor for Llama3 scaling.

Type

tp.Optional[float]

high_freq_factor#

Specific factor for Llama3 scaling.

Type

tp.Optional[float]

original_max_position_embeddings#

Original context window size, required by some scaling methods (yarn, llama3, phi3, deepseek).

Type

tp.Optional[int]

long_factor#

Specific factor for Phi3 LongRoPE scaling (used for lengths > original).

Type

tp.Optional[float]

short_factor#

Specific factor for Phi3 LongRoPE scaling (used for lengths <= original).

Type

tp.Optional[float]

long_mscale#

Potentially used by variants like Phi3. (Not used in current get_rope).

Type

tp.Optional[float]

short_mscale#

Potentially used by variants like Phi3. (Not used in current get_rope).

Type

tp.Optional[float]

Additional scaling parameters (e.g., from YaRN, Deepseek):

extrapolation_factor#

YaRN/Deepseek parameter.

Type

tp.Optional[float]

attn_factor#

YaRN/Deepseek parameter.

Type

tp.Optional[float]

beta_fast#

YaRN/Deepseek parameter.

Type

tp.Optional[int]

beta_slow#

YaRN/Deepseek parameter.

Type

tp.Optional[int]

mscale#

Deepseek parameter.

Type

tp.Optional[float]

mscale_all_dim#

Deepseek parameter.

Type

tp.Optional[float]

factor: Optional[float] = None#
classmethod from_dict(config_dict: Dict[str, Any]) RopeConfig[source]#

Create a RopeConfig instance from a dictionary.

Handles the potential alias ‘type’ for ‘rope_type’.

Parameters

config_dict (tp.Dict[str, tp.Any]) – Dictionary containing RoPE configuration.

Returns

An instance populated from the dictionary.

Return type

RopeConfig

from_tuple()#
high_freq_factor: Optional[float] = None#
items() a set-like object providing a view on D's items#
keys() a set-like object providing a view on D's keys#
long_factor: Optional[float] = None#
long_mscale: Optional[float] = None#
low_freq_factor: Optional[float] = None#
original_max_position_embeddings: Optional[int] = None#
replace(**kwargs)#
rope_type: str = 'default'#
short_factor: Optional[float] = None#
short_mscale: Optional[float] = None#
to_dict() Dict[str, Any][source]#

Convert the RopeConfig instance to a dictionary.

Filters out attributes with None values. The dictionary is made hashable using a custom class for potential use with JIT compilation contexts (though making the dict itself static in get_frequencies is preferred).

Returns

A hashable dictionary containing non-None configuration values.

Return type

tp.Dict[str, tp.Any]
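
The two conversions described above (accepting ‘type’ as an alias in from_dict, and dropping None values in to_dict) can be sketched with a small dataclass. RopeConfigSketch is a hypothetical stand-in with only three of the documented fields, not the EasyDeL class, and it does not model the hashable-dictionary wrapper mentioned above.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class RopeConfigSketch:
    """Hypothetical stand-in for RopeConfig with a subset of its fields."""

    rope_type: str = "default"
    factor: Optional[float] = None
    original_max_position_embeddings: Optional[int] = None

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "RopeConfigSketch":
        # Prefer "rope_type"; fall back to the alias "type", then to "default".
        rope_type = config_dict.get("rope_type", config_dict.get("type", "default"))
        known = {f.name for f in fields(cls)} - {"rope_type"}
        extra = {k: v for k, v in config_dict.items() if k in known}
        return cls(rope_type=rope_type, **extra)

    def to_dict(self) -> Dict[str, Any]:
        # Keep only the attributes that are not None.
        values = {f.name: getattr(self, f.name) for f in fields(self)}
        return {k: v for k, v in values.items() if v is not None}


cfg = RopeConfigSketch.from_dict({"type": "linear", "factor": 2.0})
round_trip = cfg.to_dict()
```

Here original_max_position_embeddings stays None and is therefore absent from round_trip.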

to_tuple()#
values() an object providing a view on D's values#
class easydel.layers.rotary_embedding.RotaryEmbedding(*args: Any, **kwargs: Any)[source]#

Bases: Module

Standard Rotary Positional Embedding (RoPE) module.

head_size#

The dimension size of each attention head.

Type

int

rotary_dim#

The dimension size of the rotary embeddings applied. Can be <= head_size.

Type

int

max_position_embeddings#

The maximum sequence length the model can handle.

Type

int

base#

The base value for calculating frequencies.

Type

int

is_neox_style#

Flag indicating whether to use Neox-style rotation.

Type

bool

dtype#

Data type for computations.

Type

jnp.dtype

class easydel.layers.rotary_embedding.YaRNScalingRotaryEmbedding(*args: Any, **kwargs: Any)[source]#

Bases: RotaryEmbedding

RotaryEmbedding extended with the YaRN (Yet another RoPE extensioN method) scaling.

Combines interpolation and extrapolation with frequency correction and magnitude scaling.

scaling_factor#

The primary scaling factor for context length.

Type

tp.Union[float, int]

extrapolation_factor#

Controls the strength of extrapolation correction.

Type

float

attn_factor#

Scales the output attention values.

Type

float

beta_fast#

YaRN parameter for high-frequency dimensions correction range.

Type

int

beta_slow#

YaRN parameter for low-frequency dimensions correction range.

Type

int

Inherits other attributes from RotaryEmbedding. Note: max_position_embeddings in the parent init likely refers to the original max length for YaRN calculations.

easydel.layers.rotary_embedding.apply_basic_rope(query: ~jax.Array, key: ~jax.Array, positions: ~jax.Array, frequencies: ~jax.Array, rotary_dim: int, is_neox_style: bool, offsets: ~jax.Array = None, dtype: ~numpy.dtype = <class 'jax.numpy.float32'>)[source]#

Applies standard or partially applied RoPE to query and key tensors.

Selects frequencies based on positions (and optional offsets), then applies the rotation using _apply_rotary_emb. Handles cases where RoPE is applied only to a subset of the head dimension (rotary_dim < query.shape[-1]).

Parameters
  • query (jax.Array) – Query tensor. Shape [..., sequence_length, num_heads, head_dim].

  • key (jax.Array) – Key tensor. Shape [..., sequence_length, num_heads, head_dim].

  • positions (jax.Array) – Array of positions for lookup in the frequency cache. Shape [sequence_length].

  • frequencies (jax.Array) – Precomputed frequency cache. Shape [max_length, rotary_dim_freq].

  • rotary_dim (int) – The dimension up to which RoPE is applied.

  • is_neox_style (bool) – Whether to use Neox-style rotation.

  • offsets (jax.Array, optional) – Optional offsets to add to positions. Defaults to None.

  • dtype (jnp.dtype, optional) – Output dtype. Defaults to jnp.float32.

Returns

The rotated query and key tensors with the specified dtype.

Return type

tp.Tuple[jax.Array, jax.Array]
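
As a rough illustration of the rotation step, the sketch below rotates a single head vector in plain Python. The helper names rotate_vector and apply_partial_rope are hypothetical and are not EasyDeL functions. It assumes the usual conventions: Neox style pairs element i with element i + half, GPT-J style pairs adjacent elements, and entries beyond rotary_dim pass through unchanged.

```python
import math
from typing import List


def rotate_vector(x: List[float], cos: List[float], sin: List[float],
                  is_neox_style: bool) -> List[float]:
    """Rotate one head vector; cos and sin each hold len(x) // 2 values."""
    half = len(x) // 2
    if is_neox_style:
        # Neox style: pair element i with element i + half.
        x1, x2 = x[:half], x[half:]
    else:
        # GPT-J style: pair adjacent elements (2i, 2i + 1).
        x1, x2 = x[0::2], x[1::2]
    o1 = [a * c - b * s for a, b, c, s in zip(x1, x2, cos, sin)]
    o2 = [b * c + a * s for a, b, c, s in zip(x1, x2, cos, sin)]
    if is_neox_style:
        return o1 + o2
    out = [0.0] * len(x)
    out[0::2], out[1::2] = o1, o2  # interleave the rotated pairs again
    return out


def apply_partial_rope(x: List[float], cos: List[float], sin: List[float],
                       rotary_dim: int, is_neox_style: bool) -> List[float]:
    # Rotate only the first rotary_dim entries; the rest are left untouched.
    return rotate_vector(x[:rotary_dim], cos, sin, is_neox_style) + x[rotary_dim:]


# Quarter-turn rotation of the first two of three entries.
angle = math.pi / 2
partial = apply_partial_rope([1.0, 0.0, 5.0], [math.cos(angle)], [math.sin(angle)], 2, True)
```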

easydel.layers.rotary_embedding.apply_phi3_rope(query, key, positions, frequencies, offsets: ~jax.Array = None, dtype: ~numpy.dtype = <class 'jax.numpy.float32'>)[source]#

Applies Phi-3 LongRoPE to query and key tensors.

Uses a specific rotation application style (_rotate_neox) assumed by Phi-3.

Parameters
  • query (jax.Array) – Query tensor. Shape [batch_size, sequence_length, num_heads, head_dim].

  • key (jax.Array) – Key tensor. Shape [batch_size, sequence_length, num_heads, head_dim].

  • positions (jax.Array) – Array of positions. Shape [sequence_length].

  • frequencies (jax.Array) – Precomputed Phi-3 frequency cache. Shape [1, max_length, rotary_dim].

  • offsets (jax.Array, optional) – Optional offsets to add to positions. Defaults to None.

  • dtype (jnp.dtype, optional) – Output dtype. Defaults to jnp.float32.

Returns

The rotated query and key tensors with the specified dtype.

Return type

tp.Tuple[jax.Array, jax.Array]

easydel.layers.rotary_embedding.compute_basic_frequencies(base: int, rotary_dim: int, max_position_embeddings: int)[source]#

Computes the basic RoPE frequencies (cos and sin values) for all positions.

Parameters
  • base (int) – The base value for the geometric progression of frequencies.

  • rotary_dim (int) – The dimension of the rotary embeddings.

  • max_position_embeddings (int) – The maximum sequence length.

Returns

A frequency cache tensor of shape (max_position_embeddings, rotary_dim). Contains concatenated cos and sin values.

Return type

jnp.ndarray
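
The cache layout described above can be sketched in plain Python (nested lists instead of a JAX array). The function name basic_frequency_cache is hypothetical, and the ordering "all cos values, then all sin values" within a row is an assumption based on the word "concatenated".

```python
import math
from typing import List


def basic_frequency_cache(base: int, rotary_dim: int,
                          max_position_embeddings: int) -> List[List[float]]:
    """Build a (max_position_embeddings, rotary_dim) table of cos and sin values."""
    # One inverse frequency per rotated pair: base ** (-2i / rotary_dim).
    inv_freq = [1.0 / base ** (i / rotary_dim) for i in range(0, rotary_dim, 2)]
    cache = []
    for pos in range(max_position_embeddings):
        angles = [pos * f for f in inv_freq]
        # Row layout assumed here: cos half first, then sin half.
        cache.append([math.cos(a) for a in angles] + [math.sin(a) for a in angles])
    return cache


cache = basic_frequency_cache(base=10000, rotary_dim=4, max_position_embeddings=3)
```

Position 0 has angle 0 for every pair, so its row is all ones in the cos half and all zeros in the sin half.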

easydel.layers.rotary_embedding.compute_basic_inv_frequencies(base: int, rotary_dim: int)[source]#

Computes the inverse frequencies for standard RoPE.

Parameters
  • base (int) – The base value for the geometric progression of frequencies.

  • rotary_dim (int) – The dimension of the rotary embeddings.

Returns

An array of inverse frequencies of shape (rotary_dim // 2,).

Return type

jnp.ndarray
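
A minimal sketch of the geometric progression, assuming the standard RoPE definition in which pair i gets the inverse frequency base ** (-2i / rotary_dim). The function name is hypothetical and plain Python lists stand in for the returned array.

```python
from typing import List


def basic_inv_frequencies(base: int, rotary_dim: int) -> List[float]:
    """Return rotary_dim // 2 inverse frequencies in geometric progression."""
    return [1.0 / (base ** (i / rotary_dim)) for i in range(0, rotary_dim, 2)]


inv_freq = basic_inv_frequencies(base=10000, rotary_dim=8)
# Consecutive entries differ by the constant ratio base ** (-2 / rotary_dim).
ratios = [b / a for a, b in zip(inv_freq, inv_freq[1:])]
```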

easydel.layers.rotary_embedding.compute_deepseek_frequencies(base, rotary_dim, scaling_factor, extrapolation_factor, beta_fast, beta_slow, max_position_embeddings, mscale, mscale_all_dim, attn_factor) Array[source]#

Computes RoPE frequencies using the Deepseek-YaRN scaling method.

Similar to YaRN but potentially uses different mscale calculation parameters.

Parameters
  • base (float) – The base value for positional encoding.

  • rotary_dim (int) – The dimension of the rotary embeddings.

  • scaling_factor (float) – The factor by which the context length is scaled.

  • extrapolation_factor (float) – YaRN parameter controlling extrapolation strength.

  • beta_fast (int) – YaRN parameter for faster rotating dimensions.

  • beta_slow (int) – YaRN parameter for slower rotating dimensions.

  • max_position_embeddings (int) – Original maximum sequence length before scaling.

  • mscale (float) – Parameter for yarn_get_mscale calculation.

  • mscale_all_dim (float) – Parameter for yarn_get_mscale calculation.

  • attn_factor (float) – Scaling factor applied to attention outputs.

Returns

A frequency cache tensor of shape (max_position_embeddings * scaling_factor, rotary_dim).

Return type

jnp.ndarray

easydel.layers.rotary_embedding.compute_dynamic_frequencies(base: int, rotary_dim: int, max_position_embeddings: int, scaling_factor: float)[source]#

Computes RoPE frequencies using Dynamic NTK scaling.

Adjusts the ‘base’ dynamically based on the scaling factor.

Parameters
  • base (int) – The initial base value before dynamic adjustment.

  • rotary_dim (int) – The dimension of the rotary embeddings.

  • max_position_embeddings (int) – The base maximum sequence length before scaling.

  • scaling_factor (float) – The scaling factor applied to the sequence length.

Returns

A frequency cache tensor of shape (max_position_embeddings * scaling_factor, rotary_dim).

Return type

jnp.ndarray
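
For orientation, the base adjustment commonly used for Dynamic NTK scaling can be sketched as follows. Whether EasyDeL uses exactly this expression is an assumption; the function name dynamic_ntk_base is hypothetical.

```python
def dynamic_ntk_base(base: float, rotary_dim: int,
                     max_position_embeddings: int, scaling_factor: float) -> float:
    """Return the enlarged base for a context stretched by scaling_factor."""
    # Target length after scaling.
    max_len = max_position_embeddings * scaling_factor
    # Assumed formula: the base grows with the length ratio, raised to d / (d - 2).
    growth = (scaling_factor * max_len / max_position_embeddings) - (scaling_factor - 1)
    return base * growth ** (rotary_dim / (rotary_dim - 2))


unchanged = dynamic_ntk_base(10000, 64, 2048, 1.0)  # factor 1 leaves the base as is
enlarged = dynamic_ntk_base(10000, 64, 2048, 2.0)   # factor 2 raises the base
```

The adjusted base would then be passed to the ordinary frequency computation over max_len positions.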

easydel.layers.rotary_embedding.compute_linear_frequencies(base: int, rotary_dim: int, max_position_embeddings: int, scaling_factors: List[float])[source]#

Computes RoPE frequencies using linear scaling for potentially multiple factors.

This function computes frequency caches for each scaling factor and concatenates them. Note: This implementation seems designed for a specific use case where different parts of a sequence might use different scaling factors, determined by offsets. If only one scaling factor is used, it behaves like standard linear scaling.

Parameters
  • base (int) – The base value for the geometric progression of frequencies.

  • rotary_dim (int) – The dimension of the rotary embeddings.

  • max_position_embeddings (int) – The base maximum sequence length before scaling.

  • scaling_factors (tp.Union[tp.List[float], float]) – A single scaling factor or a list of scaling factors.

Returns

A frequency cache tensor of shape (total_scaled_length, rotary_dim). If multiple scaling factors are provided, the caches are concatenated along the position dimension.

Return type

jnp.ndarray
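
The single-factor case can be sketched as below: each position index is divided by the scaling factor before the angle is computed, and the cache covers max_position_embeddings * scaling_factor positions. This is a plain-Python sketch under those assumptions, the function name is hypothetical, and the multi-factor concatenation is not modeled.

```python
import math
from typing import List


def linear_frequency_cache(base: int, rotary_dim: int,
                           max_position_embeddings: int,
                           scaling_factor: float) -> List[List[float]]:
    """Cos/sin cache where positions are compressed by scaling_factor."""
    inv_freq = [1.0 / base ** (i / rotary_dim) for i in range(0, rotary_dim, 2)]
    max_len = int(max_position_embeddings * scaling_factor)
    cache = []
    for pos in range(max_len):
        scaled_pos = pos / scaling_factor  # linear position interpolation
        angles = [scaled_pos * f for f in inv_freq]
        cache.append([math.cos(a) for a in angles] + [math.sin(a) for a in angles])
    return cache


scaled = linear_frequency_cache(base=10000, rotary_dim=4,
                                max_position_embeddings=4, scaling_factor=2.0)
```

With a factor of 2, row 2 of the scaled cache carries the angles that row 1 would have in an unscaled cache.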

easydel.layers.rotary_embedding.compute_llama3_frequencies(base, rotary_dim, low_freq_factor, high_freq_factor, scaling_factor, max_position_embeddings: int)[source]#

Computes RoPE frequencies using the Llama3 scaling method.

Parameters
  • base (float) – The base value for positional encoding.

  • rotary_dim (int) – The dimension of the rotary embeddings.

  • low_freq_factor (float) – Factor for adjusting low-frequency components.

  • high_freq_factor (float) – Factor for adjusting high-frequency components.

  • scaling_factor (float) – The overall scaling factor applied.

  • max_position_embeddings (int) – Original maximum sequence length (referred to as orig_max_position in compute_llama3_inv_frequencies). This defines the length of the frequency cache.

Returns

A frequency cache tensor of shape (max_position_embeddings, rotary_dim).

Return type

jnp.ndarray

easydel.layers.rotary_embedding.compute_llama3_inv_frequencies(base, rotary_dim, low_freq_factor, high_freq_factor, orig_max_position, scaling_factor)[source]#

Computes the inverse frequencies for Llama3-style scaled RoPE.

Adjusts frequencies based on wavelength thresholds and a smoothing factor.

Parameters
  • base (float) – The base value for positional encoding.

  • rotary_dim (int) – The dimension of the rotary embeddings.

  • low_freq_factor (float) – Factor for adjusting low-frequency components.

  • high_freq_factor (float) – Factor for adjusting high-frequency components.

  • orig_max_position (int) – Original maximum sequence length before scaling.

  • scaling_factor (float) – The overall scaling factor applied.

Returns

An array of Llama3-adjusted inverse frequencies of shape (rotary_dim // 2,).

Return type

jnp.ndarray
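
The wavelength thresholds and smoothing factor can be illustrated with the publicly known Llama 3 rule, sketched here in plain Python. It is an assumption that EasyDeL follows this rule term for term; the function name is hypothetical.

```python
import math
from typing import List


def llama3_inv_frequencies(base: float, rotary_dim: int, low_freq_factor: float,
                           high_freq_factor: float, orig_max_position: int,
                           scaling_factor: float) -> List[float]:
    """Scale low frequencies, keep high ones, and blend the band in between."""
    low_freq_wavelen = orig_max_position / low_freq_factor
    high_freq_wavelen = orig_max_position / high_freq_factor
    adjusted = []
    for i in range(0, rotary_dim, 2):
        freq = 1.0 / base ** (i / rotary_dim)
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            adjusted.append(freq)                   # short wavelength: unchanged
        elif wavelen > low_freq_wavelen:
            adjusted.append(freq / scaling_factor)  # long wavelength: fully scaled
        else:
            # Middle band: interpolate smoothly between the two regimes.
            smooth = (orig_max_position / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            adjusted.append((1 - smooth) * freq / scaling_factor + smooth * freq)
    return adjusted


inv_freq_l3 = llama3_inv_frequencies(500000.0, 128, 1.0, 4.0, 8192, 8.0)
```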

easydel.layers.rotary_embedding.compute_phi3_frequencies(base, head_size, rotary_dim, max_position_embeddings, original_max_position_embeddings, short_factor, long_factor)[source]#

Computes RoPE frequencies using the Phi-3 LongRoPE scaling method.

Applies different scaling factors based on whether the target length is shorter or longer than the original max length. Includes a scaling factor adjustment based on the ratio of target length to original length.

Parameters
  • base (float) – The base value for positional encoding.

  • head_size (int) – The dimension of each attention head.

  • rotary_dim (int) – The dimension of the rotary embeddings. Must equal head_size for Phi-3.

  • max_position_embeddings (int) – The target maximum sequence length after scaling.

  • original_max_position_embeddings (int) – Original maximum sequence length before scaling.

  • short_factor (tp.List[float]) – Scaling factors for frequencies when max_position_embeddings <= original_max_position_embeddings.

  • long_factor (tp.List[float]) – Scaling factors for frequencies when max_position_embeddings > original_max_position_embeddings.

Returns

A frequency cache tensor of shape (1, max_position_embeddings, rotary_dim).

Return type

jnp.ndarray

Raises

ValueError – If rotary_dim does not equal head_size.
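
The factor selection and the length-ratio adjustment can be sketched as follows. The square-root expression is the one commonly associated with Phi-3 LongRoPE; that EasyDeL uses it unchanged is an assumption, and the function name phi3_inv_frequencies_and_scale is hypothetical.

```python
import math
from typing import List, Tuple


def phi3_inv_frequencies_and_scale(
    base: float, rotary_dim: int, max_position_embeddings: int,
    original_max_position_embeddings: int,
    short_factor: List[float], long_factor: List[float],
) -> Tuple[List[float], float]:
    """Return per-pair inverse frequencies and the magnitude adjustment."""
    # Pick long_factor only when the target length exceeds the original length.
    if max_position_embeddings > original_max_position_embeddings:
        factors = long_factor
    else:
        factors = short_factor
    inv_freq = [
        1.0 / (f * base ** (i / rotary_dim))
        for f, i in zip(factors, range(0, rotary_dim, 2))
    ]
    scale = max_position_embeddings / original_max_position_embeddings
    if scale <= 1.0:
        magnitude = 1.0
    else:
        # Assumed adjustment: sqrt(1 + log(scale) / log(original length)).
        magnitude = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))
    return inv_freq, magnitude


short, long_ = [1.0, 1.0], [2.0, 2.0]
inv_short, mag_short = phi3_inv_frequencies_and_scale(10000.0, 4, 4096, 4096, short, long_)
inv_long, mag_long = phi3_inv_frequencies_and_scale(10000.0, 4, 131072, 4096, short, long_)
```

In this sketch the magnitude would multiply the cos and sin values of the cache.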

easydel.layers.rotary_embedding.compute_yarn_frequencies(base: float, rotary_dim: int, beta_fast: float, beta_slow: float, max_position_embeddings: int, scaling_factor: float, extrapolation_factor: float, attn_factor: float) Array[source]#

Computes RoPE frequencies using the YaRN scaling method.

Includes adjustments based on YaRN parameters and applies an mscale factor.

Parameters
  • base (float) – The base value for positional encoding.

  • rotary_dim (int) – The dimension of the rotary embeddings.

  • beta_fast (float) – YaRN parameter for faster rotating dimensions.

  • beta_slow (float) – YaRN parameter for slower rotating dimensions.

  • max_position_embeddings (int) – Original maximum sequence length before scaling.

  • scaling_factor (float) – The factor by which the context length is scaled.

  • extrapolation_factor (float) – YaRN parameter controlling extrapolation strength.

  • attn_factor (float) – YaRN parameter scaling the attention outputs.

Returns

A frequency cache tensor of shape (max_position_embeddings * scaling_factor, rotary_dim).

Return type

jnp.ndarray

easydel.layers.rotary_embedding.compute_yarn_inv_frequencies(base: float, rotary_dim: int, beta_fast: float, beta_slow: float, max_position_embeddings: int, scaling_factor: float, extrapolation_factor: float) Array[source]#

Computes the inverse frequencies for YaRN scaled RoPE.

Combines interpolation and extrapolation frequencies based on correction ranges.

Parameters
  • base (float) – The base value for positional encoding.

  • rotary_dim (int) – The dimension of the rotary embeddings.

  • beta_fast (float) – YaRN parameter for faster rotating dimensions.

  • beta_slow (float) – YaRN parameter for slower rotating dimensions.

  • max_position_embeddings (int) – Original maximum sequence length before scaling.

  • scaling_factor (float) – The factor by which the context length is scaled.

  • extrapolation_factor (float) – YaRN parameter controlling extrapolation strength.

Returns

An array of YaRN-adjusted inverse frequencies of shape (rotary_dim // 2,).

Return type

jnp.ndarray
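
The blend of interpolation and extrapolation frequencies can be sketched with the commonly published YaRN recipe: a correction range derived from beta_fast and beta_slow, and a linear ramp across that range. This is a plain-Python sketch; that EasyDeL matches it exactly is an assumption, and the helper names are hypothetical.

```python
import math
from typing import List


def find_correction_dim(num_rotations: float, dim: int, base: float,
                        max_position_embeddings: int) -> float:
    """Pair index whose frequency completes num_rotations over the original length."""
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )


def yarn_inv_frequencies(base: float, rotary_dim: int, beta_fast: float,
                         beta_slow: float, max_position_embeddings: int,
                         scaling_factor: float, extrapolation_factor: float) -> List[float]:
    low = max(math.floor(find_correction_dim(beta_fast, rotary_dim, base,
                                             max_position_embeddings)), 0)
    high = min(math.ceil(find_correction_dim(beta_slow, rotary_dim, base,
                                             max_position_embeddings)), rotary_dim - 1)
    if low == high:
        high += 0.001  # avoid a zero-width ramp
    blended = []
    for j in range(rotary_dim // 2):
        pos_freq = base ** (2 * j / rotary_dim)
        extrapolation = 1.0 / pos_freq                    # unscaled frequency
        interpolation = 1.0 / (scaling_factor * pos_freq)  # fully scaled frequency
        ramp = min(max((j - low) / (high - low), 0.0), 1.0)
        mask = (1.0 - ramp) * extrapolation_factor
        blended.append(interpolation * (1.0 - mask) + extrapolation * mask)
    return blended


inv_freq_yarn = yarn_inv_frequencies(10000.0, 128, 32, 1, 4096, 4.0, 1.0)
```

Pairs below the range keep their original frequency, pairs above it are divided by scaling_factor, and pairs inside it are mixed.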

easydel.layers.rotary_embedding.get_frequencies(head_size: int, rotary_dim: int, max_position: int, base: int, rope_scaling: Optional[Dict[str, Any]] = None, partial_rotary_factor: float = 1.0) Array[source]#

Computes and returns the RoPE frequency cache based on configuration.

Selects the appropriate frequency computation function (basic, linear, dynamic, YaRN, Llama3, Phi3, Deepseek) based on the rope_scaling dictionary. This function is JIT-compiled for performance, with relevant parameters marked static.

Parameters
  • head_size (int) – Dimension of each attention head (needed for some scaling types like Phi3).

  • rotary_dim (int) – Base dimension for rotary embedding (before partial factor).

  • max_position (int) – Maximum sequence length for which to compute frequencies. This might be the original or target length depending on scaling type.

  • base (int) – Base value for frequency calculation.

  • rope_scaling (tp.Optional[tp.Dict[str, tp.Any]], optional) – Dictionary specifying the type and parameters of RoPE scaling. Determines which frequency function to call. Defaults to None (uses compute_basic_frequencies).

  • partial_rotary_factor (float, optional) – Factor to reduce the rotary dimension. Defaults to 1.0.

Returns

The computed frequency cache tensor. Shape depends on the scaling method, typically [computed_length, rotary_dim_effective].

Return type

jax.Array

Raises

ValueError – If rope_scaling specifies an unknown rope_type.

easydel.layers.rotary_embedding.get_inv_frequencies(head_size: int, rotary_dim: int, max_position: int, base: int, rope_scaling: Optional[Dict[str, Any]] = None, partial_rotary_factor: float = 1.0) Array[source]#

Computes and returns just the inverse frequencies for RoPE based on configuration.

Similar to get_frequencies but returns only the inverse frequencies without computing the full frequency cache (no cos/sin transformation).

Parameters
  • head_size (int) – Dimension of each attention head (needed for some scaling types like Phi3).

  • rotary_dim (int) – Base dimension for rotary embedding (before partial factor).

  • max_position (int) – Maximum sequence length the model should support.

  • base (int) – Base value for frequency calculation.

  • rope_scaling (tp.Optional[tp.Dict[str, tp.Any]], optional) – Dictionary specifying the type and parameters of RoPE scaling. Determines which frequency function to call. Defaults to None (uses basic inverse frequencies).

  • partial_rotary_factor (float, optional) – Factor to reduce the rotary dimension. Defaults to 1.0.

Returns

The computed inverse frequencies. Shape is typically (rotary_dim // 2,).

Return type

jax.Array

Raises

ValueError – If rope_scaling specifies an unknown rope_type.

easydel.layers.rotary_embedding.get_rope(head_size: int, rotary_dim: int, max_position: int, base: int, is_neox_style: bool = True, rope_scaling: Optional[Dict[str, Any]] = None, dtype: Optional[dtype] = None, partial_rotary_factor: float = 1.0) RotaryEmbedding[source]#

Factory function to create and return a RotaryEmbedding instance based on configuration.

Selects the appropriate RoPE class (standard, linear, dynamic, YaRN, Llama3, Phi3, Deepseek) based on the rope_scaling dictionary.

Parameters
  • head_size (int) – Dimension of each attention head.

  • rotary_dim (int) – Base dimension for rotary embedding (before partial factor).

  • max_position (int) – Maximum sequence length the model should support (target length).

  • base (int) – Base value for frequency calculation.

  • is_neox_style (bool, optional) – Use Neox rotation style. Defaults to True.

  • rope_scaling (tp.Optional[tp.Dict[str, tp.Any]], optional) – Dictionary specifying the type and parameters of RoPE scaling. If None or ‘rope_type’ is ‘default’, uses standard RoPE. Keys like ‘rope_type’, ‘factor’, ‘original_max_position_embeddings’, etc., are used. Defaults to None.

  • dtype (tp.Optional[jnp.dtype], optional) – Data type for embeddings. Defaults to jnp.float32.

  • partial_rotary_factor (float, optional) – Factor to reduce the rotary dimension (e.g., 0.5 applies RoPE to half the dimensions). Defaults to 1.0.

Returns

An instance of the configured RotaryEmbedding subclass.

Return type

RotaryEmbedding

Raises

ValueError – If rope_scaling specifies an unknown rope_type.

easydel.layers.rotary_embedding.rope_wraper(type)[source]#

A decorator factory that registers a RotaryEmbedding class under a specific type name.

This allows retrieving RoPE configurations by type name later. It also sets basic __str__ and __repr__ for the decorated class.

Parameters

type (str) – The name to register the RoPE class under (e.g., “linear”, “yarn”).

Returns

A decorator function that takes a RotaryEmbedding class, registers it, and returns the class.

Return type

Callable
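
The decorator-factory pattern described above can be sketched as follows. REGISTRY_SKETCH and register_rope are hypothetical names. For simplicity the sketch stores the class itself, whereas the real registry appears to store a mapping of the class's attributes.

```python
from typing import Callable, Dict

# Hypothetical registry keyed by RoPE type name.
REGISTRY_SKETCH: Dict[str, type] = {}


def register_rope(type_name: str) -> Callable[[type], type]:
    """Return a decorator that registers a class under type_name."""

    def decorator(cls: type) -> type:
        REGISTRY_SKETCH[type_name] = cls
        # Give the class a basic __repr__ and reuse it for __str__.
        cls.__repr__ = lambda self: f"{cls.__name__}()"
        cls.__str__ = cls.__repr__
        return cls

    return decorator


@register_rope("linear")
class LinearSketch:
    """Placeholder class standing in for a RoPE module."""


registered = REGISTRY_SKETCH["linear"]
```

Because the decorator returns the class unchanged apart from the two dunder methods, the decorated class can still be used directly.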

easydel.layers.rotary_embedding.yarn_get_mscale(scale: float = 1, mscale: float = 1) float[source]#

Calculates the mscale factor, potentially used by Deepseek-YaRN or similar methods.

Allows specifying an additional mscale parameter compared to _yarn_get_mscale.

Parameters
  • scale (float, optional) โ€“ The scaling factor. Defaults to 1.

  • mscale (float, optional) โ€“ An additional scaling parameter. Defaults to 1.

Returns

The calculated mscale value. Returns 1.0 if scale <= 1.

Return type

float
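
A minimal sketch of the calculation, assuming the logarithmic form commonly used in YaRN-style implementations (0.1 * mscale * ln(scale) + 1). The documented behavior for scale <= 1 is taken from the description above; the rest is an assumption, and the function name is hypothetical.

```python
import math


def yarn_get_mscale_sketch(scale: float = 1.0, mscale: float = 1.0) -> float:
    """Return the magnitude scale; 1.0 when the context is not extended."""
    if scale <= 1:
        return 1.0
    # Assumed form: grows logarithmically with the context scaling factor.
    return 0.1 * mscale * math.log(scale) + 1.0


no_extension = yarn_get_mscale_sketch(1.0)
extended = yarn_get_mscale_sketch(math.e, 1.0)  # ln(e) = 1, so roughly 1.1
```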