# Source code for easydel.layers.rotary_embedding

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import math
import typing as tp

import chex
import jax
import jax.numpy as jnp
from flax import nnx as nn
from jaxtyping import Array, Float

from easydel.utils.compiling_utils import ejit


@jax.named_scope("easydel-rotary-yarn-find-correction-dim")
def _yarn_find_correction_dim(
    num_rotations: int,
    dim: int,
    base: float = 10000,
    max_position_embeddings: int = 2048,
) -> float:
    """
    Calculates the correction dimension for YaRN scaling.

    Internal helper function for YaRN.

    Args:
        num_rotations (int): Number of rotations.
        dim (int): The dimension of the embeddings.
        base (float, optional): The base value for positional encoding. Defaults to 10000.
        max_position_embeddings (int, optional): The maximum sequence length. Defaults to 2048.

    Returns:
        float: The calculated correction dimension.
    """
    return (
        dim
        * jnp.log(
            max_position_embeddings / (num_rotations * 2 * jnp.pi),
        )
    ) / (2 * jnp.log(base))


@jax.named_scope("easydel-rotary-yarn-find-correction-range")
def _yarn_find_correction_range(
    low_rot: int,
    high_rot: int,
    dim: int,
    base: float = 10000,
    max_position_embeddings: int = 2048,
) -> tuple[jax.Array, jax.Array]:
    """
    Find the dimension range over which YaRN blends interpolation and extrapolation.

    Internal helper function for YaRN. The lower bound is the floor of the
    correction dimension for ``low_rot`` rotations. The upper bound is the ceiling
    of the correction dimension for ``high_rot`` rotations.

    Args:
        low_rot (int): Rotation count defining the lower boundary.
        high_rot (int): Rotation count defining the upper boundary.
        dim (int): The dimension of the embeddings.
        base (float, optional): The base value for positional encoding. Defaults to 10000.
        max_position_embeddings (int, optional): The maximum sequence length. Defaults to 2048.

    Returns:
        tuple[jax.Array, jax.Array]: ``(low, high)`` as floating-point JAX scalars
        holding whole numbers (floor / ceil results). ``low`` is clamped to be
        >= 0 and ``high`` to be <= ``dim - 1``.
    """
    hr = jnp.ceil(_yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
    lr = jnp.floor(_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
    # Build the upper clamp in `hr`'s own dtype: lax.min requires matching dtypes,
    # and a hard-coded float32 bound can disagree with `hr` (e.g. float64 under x64).
    upper = jnp.asarray(dim - 1, dtype=hr.dtype)
    return jax.lax.max(lr, 0.0), jax.lax.min(hr, upper)


@jax.named_scope("easydel-rotary-yarn-linear-ramp-mask")
def _yarn_linear_ramp_mask(
    low: float,
    high: float,
    dim: int,
    dtype: jnp.dtype,
) -> jnp.ndarray:
    """
    Creates a linear ramp mask for YaRN scaling.

    Internal helper function for YaRN. Generates a mask that ramps linearly from 0 to 1
    between the `low` and `high` dimension indices.

    Args:
        low (float): The starting dimension index for the ramp.
        high (float): The ending dimension index for the ramp.
        dim (int): The total dimension of the mask.
        dtype (jnp.dtype): The data type for the mask array.

    Returns:
        jnp.ndarray: A 1D array of shape (dim,) representing the linear ramp mask,
                     clipped between 0 and 1.
    """
    high = jax.lax.cond(low == high, lambda x: x + 0.001, lambda x: x, high)
    linear_func = (jnp.arange(dim, dtype=dtype) - low) / (high - low)
    ramp_func = jnp.clip(linear_func, 0, 1)
    return ramp_func


@jax.named_scope("easydel-rotary-yarn-get-mscale")
def _yarn_get_mscale(scale: float = 1) -> float:
    """
    Calculates the mscale factor for YaRN context extension method.

    Internal helper function for YaRN.

    Args:
        scale (float, optional): The scaling factor. Defaults to 1.

    Returns:
        float: The calculated mscale value. Returns 1.0 if scale <= 1.
    """
    if scale <= 1:
        return 1.0
    return 0.1 * jnp.log(scale) + 1.0


@jax.named_scope("easydel-rotary-rotate-neox")
def _rotate_neox(x: Float[Array, "... seq_len head_dim"]) -> Float[Array, "... seq_len head_dim"]:
    """
    Applies the Neox-style rotation to the input array.

    Splits the last dimension in half and concatenates the negated second half
    with the first half.

    Args:
        x (jnp.ndarray): The input array.

    Returns:
        jnp.ndarray: The rotated array.
    """
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return jnp.concatenate((-x2, x1), axis=-1)


@jax.named_scope("easydel-rotary-rotate-gptj")
def _rotate_gptj(x: Float[Array, "... seq_len head_dim"]) -> Float[Array, "... seq_len head_dim"]:
    """
    Applies the GPT-J-style rotation to the input array.

    Interleaves the negated odd-indexed elements with the even-indexed elements.

    Args:
        x (jnp.ndarray): The input array.

    Returns:
        jnp.ndarray: The rotated array.
    """
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    x = jnp.stack((-x2, x1), axis=-1)
    return x.reshape((*x.shape[:-2], -1))


@jax.named_scope("easydel-rotary-apply-rotary-emb")
def _apply_rotary_emb(
    x: jnp.ndarray,
    cos: jnp.ndarray,
    sin: jnp.ndarray,
    is_neox_style: bool,
) -> jnp.ndarray:
    """
    Applies rotary positional embedding to the input tensor.

    Args:
        x (jnp.ndarray): Input tensor, e.g., query or key. Expected shape
                         [..., num_tokens, head_size] or similar.
        cos (jnp.ndarray): Cosine components of the embedding. Expected shape
                           compatible for broadcasting with `x` after rotation,
                           e.g., [..., num_tokens, head_size//2].
        sin (jnp.ndarray): Sine components of the embedding. Expected shape
                           compatible for broadcasting with `x` after rotation,
                           e.g., [..., num_tokens, head_size//2].
        is_neox_style (bool): Whether to use Neox-style rotation (`_rotate_neox`)
                              or GPT-J-style rotation (`_rotate_gptj`).

    Returns:
        jnp.ndarray: The tensor with rotary embeddings applied.
    """
    cos = cos[:, :, None].astype(x.dtype)
    sin = sin[:, :, None].astype(x.dtype)
    assert sin.ndim == x.ndim
    if is_neox_style:
        x1, x2 = jnp.split(x, 2, axis=-1)
    else:
        x1 = x[..., ::2]
        x2 = x[..., 1::2]

    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin

    if is_neox_style:
        return jnp.concatenate((o1, o2), axis=-1)
    else:
        return jnp.stack((o1, o2), axis=-1).reshape(x.shape)


# Registry of RoPE variants: maps the type name given to ``rope_wraper``
# (e.g. "default", "linear") to a snapshot of the registered class's ``__dict__``.
AVAILABLE_ROPE_TYPES: dict[str, dict] = dict()


def rope_wraper(type):  # noqa
    """
    Build a class decorator that registers a RoPE implementation under ``type``.

    The decorator stores a copy of the class's ``__dict__`` in
    ``AVAILABLE_ROPE_TYPES[type]``. It also tags the class with a ``_type``
    attribute and replaces its ``__str__``/``__repr__`` so instances render as
    their class name (``repr`` quotes it).

    Args:
        type (str): Registry key for the RoPE class (e.g. "linear", "yarn").

    Returns:
        Callable: Decorator that registers the class and returns it unchanged
        apart from the attributes listed above.
    """

    def register(rope: RotaryEmbedding):
        """
        Record ``rope`` in the registry and attach the naming helpers.

        Args:
            rope (RotaryEmbedding): The class being decorated.

        Returns:
            RotaryEmbedding: The same class object.
        """
        AVAILABLE_ROPE_TYPES[type] = dict(rope.__dict__)
        rope.__str__ = lambda cls: str(cls.__class__.__name__)
        rope.__repr__ = lambda cls: repr(cls.__class__.__name__)
        rope._type = type
        return rope

    return register
@jax.named_scope("easydel-rotary-compute-basic-inv-frequencies")
def compute_basic_inv_frequencies(base: int, rotary_dim: int):
    """
    Standard RoPE inverse frequencies ``base ** (-2i / rotary_dim)``.

    Args:
        base (int): Base of the geometric frequency progression.
        rotary_dim (int): Rotary embedding dimension.

    Returns:
        jnp.ndarray: float32 array of shape ``(ceil(rotary_dim / 2),)``, one entry
        per even index ``0, 2, 4, ... < rotary_dim``.
    """
    exponents = jnp.arange(0, rotary_dim, 2, dtype="f4") / rotary_dim
    return 1.0 / (base**exponents)
@jax.named_scope("easydel-rotary-compute-yarn-inv-frequencies")
def compute_yarn_inv_frequencies(
    base: float,
    rotary_dim: int,
    beta_fast: float,
    beta_slow: float,
    max_position_embeddings: int,
    scaling_factor: float,
    extrapolation_factor: float,
) -> jnp.ndarray:
    """
    YaRN inverse frequencies: a per-dimension blend of extrapolated and
    interpolated (``/ scaling_factor``) RoPE frequencies.

    The blend weight for extrapolation is
    ``(1 - ramp) * extrapolation_factor``. Here ``ramp`` is the linear ramp
    between the correction-range bounds obtained from ``beta_fast`` (lower)
    and ``beta_slow`` (upper).

    Args:
        base (float): RoPE frequency base.
        rotary_dim (int): Rotary embedding dimension.
        beta_fast (float): Rotation count for the lower correction bound.
        beta_slow (float): Rotation count for the upper correction bound.
        max_position_embeddings (int): Original context length before scaling.
        scaling_factor (float): Context-extension factor.
        extrapolation_factor (float): Multiplier on the extrapolation weight.

    Returns:
        jnp.ndarray: Inverse frequencies of shape ``(rotary_dim // 2,)``.
    """
    pos_freqs = base ** (jnp.arange(0, rotary_dim, 2, dtype=jnp.float32) / rotary_dim)
    extrapolated = 1.0 / pos_freqs
    interpolated = 1.0 / (scaling_factor * pos_freqs)

    low, high = _yarn_find_correction_range(
        low_rot=beta_fast,
        high_rot=beta_slow,
        dim=rotary_dim,
        base=base,
        max_position_embeddings=max_position_embeddings,
    )
    ramp = _yarn_linear_ramp_mask(low, high, rotary_dim // 2, dtype=jnp.float32)
    extrapolation_weight = (1 - ramp) * extrapolation_factor
    return interpolated * (1 - extrapolation_weight) + extrapolated * extrapolation_weight
@jax.named_scope("easydel-rotary-compute-llama3-inv-frequencies")
def compute_llama3_inv_frequencies(
    base,
    rotary_dim,
    low_freq_factor,
    high_freq_factor,
    orig_max_position,
    scaling_factor,
):
    """
    Llama3-style scaled RoPE inverse frequencies.

    Each base frequency is classified by its wavelength ``2*pi / inv_freq``:

    * shorter than ``orig_max_position / high_freq_factor``: kept unchanged;
    * longer than ``orig_max_position / low_freq_factor``: divided by
      ``scaling_factor``;
    * otherwise: linearly blended between the two using the smoothing factor
      ``(orig_max_position / wavelength - low_freq_factor) /
      (high_freq_factor - low_freq_factor)``. The factor is 0 when the two
      factors are equal.

    Args:
        base (float): RoPE frequency base.
        rotary_dim (int): Rotary embedding dimension.
        low_freq_factor (float): Defines the long-wavelength threshold.
        high_freq_factor (float): Defines the short-wavelength threshold.
        orig_max_position (int): Original context length before scaling.
        scaling_factor (float): Divisor applied to low-frequency components.

    Returns:
        jnp.ndarray: Adjusted inverse frequencies of shape ``(rotary_dim // 2,)``.
    """
    inv_freqs = compute_basic_inv_frequencies(base, rotary_dim)
    wavelengths = 2 * jnp.pi / inv_freqs
    low_freq_wavelen = orig_max_position / low_freq_factor
    high_freq_wavelen = orig_max_position / high_freq_factor

    if low_freq_factor == high_freq_factor:
        # Degenerate band: avoid dividing by zero.
        smooth = 0
    else:
        smooth = (orig_max_position / wavelengths - low_freq_factor) / (high_freq_factor - low_freq_factor)

    scaled = inv_freqs / scaling_factor
    blended = (1 - smooth) * inv_freqs / scaling_factor + smooth * inv_freqs
    return jnp.where(
        wavelengths < high_freq_wavelen,
        inv_freqs,
        jnp.where(wavelengths > low_freq_wavelen, scaled, blended),
    )
@jax.named_scope("easydel-rotary-compute-basic-frequencies")
def compute_basic_frequencies(
    base: int,
    rotary_dim: int,
    max_position_embeddings: int,
):
    """
    Build the standard RoPE cos/sin cache for positions ``0 .. max_position_embeddings - 1``.

    Args:
        base (int): RoPE frequency base.
        rotary_dim (int): Rotary embedding dimension.
        max_position_embeddings (int): Number of positions (rows) in the cache.

    Returns:
        jnp.ndarray: Array of shape ``(max_position_embeddings, rotary_dim)``
        laid out as ``[cos(angles), sin(angles)]`` along the last axis.
    """
    inv = compute_basic_inv_frequencies(base, rotary_dim)
    positions = jnp.arange(max_position_embeddings, dtype=jnp.float32)
    angles = jnp.einsum("i,j -> ij", positions, inv)
    return jnp.concatenate([jnp.cos(angles), jnp.sin(angles)], axis=-1)
@jax.named_scope("easydel-rotary-compute-linear-frequencies")
def compute_linear_frequencies(
    base: int,
    rotary_dim: int,
    max_position_embeddings: int,
    scaling_factors: list[float] | float,
):
    """
    Computes RoPE cos/sin caches using linear position scaling, one per factor.

    For each scaling factor ``s`` the cache covers ``max_position_embeddings * s``
    positions (rounded up by ``jnp.arange`` if fractional). Positions are divided
    by ``s`` before computing angles. The per-factor caches are stacked along the
    position axis, so segment ``k`` starts at the summed lengths of segments
    ``0 .. k-1``. Callers that use several factors must offset their positions
    accordingly.

    Args:
        base (int): The base value for the geometric progression of frequencies.
        rotary_dim (int): The dimension of the rotary embeddings.
        max_position_embeddings (int): The base maximum sequence length before scaling.
        scaling_factors (list[float] | tuple[float, ...] | float): A single
            scaling factor or a sequence of them.

    Returns:
        jnp.ndarray: Cache of shape ``(total_scaled_length, rotary_dim)`` laid out
        as ``[cos, sin]`` along the last axis.

    Raises:
        ValueError: If an empty sequence of scaling factors is given.
    """
    # Accept tuples as well as lists; previously a tuple was treated as a single
    # factor and `int * tuple` silently performed sequence repetition.
    if isinstance(scaling_factors, (list, tuple)):
        factors = list(scaling_factors)
    else:
        factors = [scaling_factors]
    if not factors:
        raise ValueError("scaling_factors must contain at least one scaling factor.")

    inv_freq = compute_basic_inv_frequencies(base=base, rotary_dim=rotary_dim)
    caches: list[jnp.ndarray] = []
    for scaling_factor in factors:
        max_len = max_position_embeddings * scaling_factor
        t = jnp.arange(max_len, dtype=jnp.float32) / scaling_factor
        freqs = jnp.einsum("i,j -> ij", t, inv_freq)
        caches.append(jnp.concatenate([jnp.cos(freqs), jnp.sin(freqs)], axis=-1))
    return jnp.concatenate(caches, axis=0)
@jax.named_scope("easydel-rotary-compute-dynamic-frequencies")
def compute_dynamic_frequencies(
    base: int,
    rotary_dim: int,
    max_position_embeddings: int,
    scaling_factor: float,
):
    """
    Build a RoPE cos/sin cache with Dynamic NTK base rescaling.

    The base becomes
    ``base * ((s * L / max_position_embeddings) - (s - 1)) ** (rotary_dim / (rotary_dim - 2))``.
    Here ``s = scaling_factor`` and ``L = max_position_embeddings * s``.

    Args:
        base (int): Base before NTK adjustment.
        rotary_dim (int): Rotary embedding dimension.
        max_position_embeddings (int): Original context length before scaling.
        scaling_factor (float): Context-extension factor.

    Returns:
        jnp.ndarray: Cache of shape ``(max_position_embeddings * scaling_factor, rotary_dim)``
        laid out as ``[cos, sin]`` along the last axis.
    """
    max_length = max_position_embeddings * scaling_factor
    growth = (scaling_factor * max_length / max_position_embeddings) - (scaling_factor - 1)
    adjusted_base = base * growth ** (rotary_dim / (rotary_dim - 2))

    inv = compute_basic_inv_frequencies(base=adjusted_base, rotary_dim=rotary_dim)
    positions = jnp.arange(max_length, dtype=jnp.float32)
    angles = jnp.einsum("i,j -> ij", positions, inv)
    return jnp.concatenate([jnp.cos(angles), jnp.sin(angles)], -1)
@jax.named_scope("easydel-rotary-compute-yarn-frequencies")
def compute_yarn_frequencies(
    base: float,
    rotary_dim: int,
    beta_fast: float,
    beta_slow: float,
    max_position_embeddings: int,
    scaling_factor: float,
    extrapolation_factor: float,
    attn_factor: float,
) -> jnp.ndarray:
    """
    Build a YaRN RoPE cos/sin cache.

    Angles come from ``compute_yarn_inv_frequencies``. Both cos and sin are
    multiplied by ``_yarn_get_mscale(scaling_factor) * attn_factor``.

    Args:
        base (float): RoPE frequency base.
        rotary_dim (int): Rotary embedding dimension.
        beta_fast (float): Rotation count for the lower correction bound.
        beta_slow (float): Rotation count for the upper correction bound.
        max_position_embeddings (int): Original context length before scaling.
        scaling_factor (float): Context-extension factor.
        extrapolation_factor (float): Multiplier on the extrapolation weight.
        attn_factor (float): Extra multiplier on the cos/sin magnitude.

    Returns:
        jnp.ndarray: Cache of shape ``(max_position_embeddings * scaling_factor, rotary_dim)``
        laid out as ``[cos, sin]`` along the last axis.
    """
    inv_freq = compute_yarn_inv_frequencies(
        base=base,
        rotary_dim=rotary_dim,
        beta_fast=beta_fast,
        beta_slow=beta_slow,
        max_position_embeddings=max_position_embeddings,
        scaling_factor=scaling_factor,
        extrapolation_factor=extrapolation_factor,
    )
    positions = jnp.arange(max_position_embeddings * scaling_factor, dtype=jnp.float32)
    angles = jnp.einsum("i,j -> ij", positions, inv_freq)
    magnitude = _yarn_get_mscale(scaling_factor) * attn_factor
    return jnp.concatenate([jnp.cos(angles) * magnitude, jnp.sin(angles) * magnitude], axis=-1)
@jax.named_scope("easydel-rotary-compute-phi3-frequencies")
def compute_phi3_frequencies(
    base,
    head_size,
    rotary_dim,
    max_position_embeddings,
    original_max_position_embeddings,
    short_factor,
    long_factor,
):
    """
    Computes RoPE frequencies using the Phi-3 LongRoPE scaling method.

    Per-frequency rescale factors are ``long_factor`` when
    ``max_position_embeddings > original_max_position_embeddings`` and
    ``short_factor`` otherwise. When the context is extended, cos and sin are
    multiplied by ``sqrt(1 + log(scale) / log(original_max_position_embeddings))``,
    where ``scale = max_position_embeddings / original_max_position_embeddings``.

    Args:
        base (float): The base value for positional encoding.
        head_size (int): The dimension of each attention head.
        rotary_dim (int): The dimension of the rotary embeddings. Must equal head_size for Phi-3.
        max_position_embeddings (int): The target maximum sequence length after scaling.
        original_max_position_embeddings (int): Original maximum sequence length before scaling.
        short_factor (list[float]): Per-frequency factors used when not extending the
            context. Must broadcast against ``head_size // 2`` frequencies.
        long_factor (list[float]): Per-frequency factors used when extending the
            context. Must broadcast against ``head_size // 2`` frequencies.

    Returns:
        jnp.ndarray: Cache of shape ``(1, max_position_embeddings, 2 * rotary_dim)``.
        It is laid out as ``[cos(emb), sin(emb)]``, where ``emb`` repeats the
        ``head_size // 2`` angles twice.

    Raises:
        ValueError: If rotary_dim does not equal head_size.
    """
    if rotary_dim != head_size:
        raise ValueError(f"rotary_dim != head_size ({rotary_dim}!={head_size})")
    if max_position_embeddings > original_max_position_embeddings:
        ext_factors = jnp.array(long_factor, dtype=jnp.float32)
    else:
        ext_factors = jnp.array(short_factor, dtype=jnp.float32)

    inv_freq_shape = jnp.arange(0, head_size, 2, dtype=jnp.int32).astype(jnp.float32) / head_size
    inv_freq = 1.0 / (ext_factors * (base**inv_freq_shape))

    positions = jnp.arange(max_position_embeddings, dtype=jnp.int32).astype(jnp.float32)
    # Outer product as an explicit broadcasted multiply. The previous K=1 matmul
    # is subject to the backend's default matmul precision (bf16 passes on TPU,
    # TF32 on recent GPUs), which would round large positions. An elementwise
    # float32 multiply is exact on every backend.
    freqs = (positions[:, None] * inv_freq[None, :])[None]  # (1, max_pos, head_size // 2)
    emb = jnp.concatenate((freqs, freqs), axis=-1)

    scale = max_position_embeddings / original_max_position_embeddings
    if scale <= 1.0:
        scaling_factor = 1.0
    else:
        scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))
    cos = jnp.cos(emb) * scaling_factor
    sin = jnp.sin(emb) * scaling_factor
    return jnp.concatenate([cos, sin], axis=-1)
@jax.named_scope("easydel-rotary-compute-llama3-frequencies")
def compute_llama3_frequencies(
    base,
    rotary_dim,
    low_freq_factor,
    high_freq_factor,
    scaling_factor,
    max_position_embeddings: int,
):
    """
    Build a Llama3-scaled RoPE cos/sin cache.

    ``max_position_embeddings`` serves two roles. It is the original context
    length passed to ``compute_llama3_inv_frequencies`` as ``orig_max_position``,
    and it is also the number of rows in the cache.

    Args:
        base (float): RoPE frequency base.
        rotary_dim (int): Rotary embedding dimension.
        low_freq_factor (float): Defines the long-wavelength threshold.
        high_freq_factor (float): Defines the short-wavelength threshold.
        scaling_factor (float): Divisor applied to low-frequency components.
        max_position_embeddings (int): Original context length / cache length.

    Returns:
        jnp.ndarray: Cache of shape ``(max_position_embeddings, rotary_dim)``
        laid out as ``[cos, sin]`` along the last axis.
    """
    inv = compute_llama3_inv_frequencies(
        base,
        rotary_dim,
        low_freq_factor,
        high_freq_factor,
        max_position_embeddings,
        scaling_factor,
    )
    positions = jnp.arange(max_position_embeddings, dtype=jnp.float32)
    angles = jnp.einsum("i,j -> ij", positions, inv)
    return jnp.concatenate([jnp.cos(angles), jnp.sin(angles)], axis=-1)
@jax.named_scope("easydel-rotary-compute-deepseek-frequencies")
def compute_deepseek_frequencies(
    base,
    rotary_dim,
    scaling_factor,
    extrapolation_factor,
    beta_fast,
    beta_slow,
    max_position_embeddings,
    mscale,
    mscale_all_dim,
    attn_factor,
) -> jnp.ndarray:
    """
    Computes RoPE frequencies using the Deepseek-YaRN scaling method.

    The inverse frequencies are exactly the YaRN ones from
    ``compute_yarn_inv_frequencies``. This function used to duplicate that
    computation inline. Only the magnitude factor differs from plain YaRN: it is
    ``yarn_get_mscale(scaling_factor, mscale) /
    yarn_get_mscale(scaling_factor, mscale_all_dim) * attn_factor``.

    Args:
        base (float): The base value for positional encoding.
        rotary_dim (int): The dimension of the rotary embeddings.
        scaling_factor (float): The factor by which the context length is scaled.
        extrapolation_factor (float): YaRN parameter controlling extrapolation strength.
        beta_fast (int): Rotation count for the lower correction bound.
        beta_slow (int): Rotation count for the upper correction bound.
        max_position_embeddings (int): Original maximum sequence length before scaling.
        mscale (float): Second argument to ``yarn_get_mscale`` in the numerator.
        mscale_all_dim (float): Second argument to ``yarn_get_mscale`` in the denominator.
        attn_factor (float): Extra multiplier on the cos/sin magnitude.

    Returns:
        jnp.ndarray: Cache of shape ``(max_position_embeddings * scaling_factor, rotary_dim)``
        laid out as ``[cos, sin]`` along the last axis.
    """
    inv_freq = compute_yarn_inv_frequencies(
        base=base,
        rotary_dim=rotary_dim,
        beta_fast=beta_fast,
        beta_slow=beta_slow,
        max_position_embeddings=max_position_embeddings,
        scaling_factor=scaling_factor,
        extrapolation_factor=extrapolation_factor,
    )
    t = jnp.arange(max_position_embeddings * scaling_factor, dtype=jnp.float32)
    freqs = jnp.einsum("i,j -> ij", t, inv_freq)

    # DeepSeek mscale calculation.
    # NOTE(review): `yarn_get_mscale` (two-argument form) is not defined in the
    # visible part of this module; presumably it is defined later in the file —
    # confirm it exists, otherwise this raises NameError at call time.
    attention_factor = (
        yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(scaling_factor, mscale_all_dim) * attn_factor
    )
    # Standard RoPE format: concatenate cos and sin.
    return jnp.concatenate([jnp.cos(freqs) * attention_factor, jnp.sin(freqs) * attention_factor], axis=-1)
@jax.named_scope("easydel-rotary-apply-basic-rope")
def apply_basic_rope(
    query: jax.Array,
    key: jax.Array,
    positions: jax.Array,
    frequencies: jax.Array,
    rotary_dim: int,
    is_neox_style: bool,
    offsets: jax.Array = None,
    dtype: jnp.dtype = jnp.float32,
):
    """
    Applies standard or partial RoPE to query and key tensors.

    Rows of ``frequencies`` are gathered at ``positions`` (plus ``offsets`` if
    given) and split into cos/sin halves. The rotation is then applied with
    ``_apply_rotary_emb``. When ``rotary_dim < query.shape[-1]``, only the first
    ``rotary_dim`` features are rotated and the rest are passed through unchanged.

    Args:
        query (jax.Array): Query tensor, e.g. ``[..., sequence_length, num_heads, head_dim]``.
        key (jax.Array): Key tensor with the same layout as ``query``.
        positions (jax.Array): Integer indices into the first axis of ``frequencies``.
            The gathered table must broadcast against the inputs once a heads axis
            is inserted, e.g. ``(batch, sequence_length)`` positions for
            ``(batch, sequence_length, num_heads, head_dim)`` inputs.
        frequencies (jax.Array): Cache of shape ``[max_length, rotary_dim]`` laid
            out as ``[cos, sin]`` along the last axis.
        rotary_dim (int): Number of leading features to rotate.
        is_neox_style (bool): Whether to use Neox-style (split-halves) rotation.
        offsets (jax.Array, optional): Added to ``positions`` before the lookup.
        dtype (jnp.dtype, optional): Accepted for API compatibility but currently
            NOT applied. The outputs keep the dtype of ``query``/``key``, because
            ``_apply_rotary_emb`` casts cos/sin to the input dtype.

    Returns:
        tuple[jax.Array, jax.Array]: The rotated query and key tensors.
    """
    if offsets is not None:
        positions = positions + offsets
    cos, sin = jnp.split(frequencies[positions], 2, -1)

    if rotary_dim == query.shape[-1]:
        return (
            _apply_rotary_emb(query, cos, sin, is_neox_style),
            _apply_rotary_emb(key, cos, sin, is_neox_style),
        )

    def _rotate_prefix(tensor):
        # Partial RoPE: rotate the first `rotary_dim` features, keep the tail as-is.
        rotated = _apply_rotary_emb(tensor[..., :rotary_dim], cos, sin, is_neox_style)
        return jnp.concatenate((rotated, tensor[..., rotary_dim:]), axis=-1)

    return _rotate_prefix(query), _rotate_prefix(key)
@jax.named_scope("easydel-rotary-apply-phi3-rope")
def apply_phi3_rope(
    query,
    key,
    positions,
    frequencies,
    offsets: jax.Array = None,
    dtype: jnp.dtype = jnp.float32,
):
    """
    Applies Phi-3 LongRoPE to query and key tensors.

    Computes ``x * cos + _rotate_neox(x) * sin`` for both inputs and casts the
    results to ``dtype``.

    Args:
        query (jax.Array): Query tensor, e.g. ``[batch_size, sequence_length, num_heads, head_dim]``.
        key (jax.Array): Key tensor with the same layout as ``query``.
        positions (jax.Array): Integer indices into axis 1 of ``frequencies``,
            e.g. ``(batch_size, sequence_length)`` or ``(sequence_length,)``. A
            heads axis is inserted just before the feature axis of the gathered
            cos/sin, so either layout broadcasts against the inputs above.
        frequencies (jax.Array): Cache of shape ``[1, max_length, 2 * head_dim]``
            laid out as ``[cos, sin]`` along the last axis (presumably produced
            by ``compute_phi3_frequencies``).
        offsets (jax.Array, optional): Added to ``positions`` before the lookup.
        dtype (jnp.dtype, optional): Output dtype. Defaults to jnp.float32.

    Returns:
        tuple[jax.Array, jax.Array]: The rotated query and key tensors cast to ``dtype``.
    """
    if offsets is not None:
        positions = positions + offsets
    cos, sin = jnp.split(frequencies[0, positions], 2, axis=-1)
    # Insert the heads axis right before the features (axis -2) so both
    # (batch, seq) and (seq,) position layouts broadcast correctly.
    cos = jnp.expand_dims(cos, -2)
    sin = jnp.expand_dims(sin, -2)
    query_rot = query * cos + _rotate_neox(query) * sin
    key_rot = key * cos + _rotate_neox(key) * sin
    return query_rot.astype(dtype), key_rot.astype(dtype)
@rope_wraper("default")
class RotaryEmbedding(nn.Module):
    """
    Standard Rotary Positional Embedding (RoPE) layer.

    Rotates the first ``rotary_dim`` features of query/key tensors. The angles
    come from a cos/sin cache, either passed in or built with
    ``compute_basic_frequencies``.

    Attributes:
        head_size (int): Size of each attention head.
        rotary_dim (int): Number of leading features that are rotated; may be
            smaller than ``head_size``, in which case the remainder passes through.
        max_position_embeddings (int): Number of rows in the generated cache.
        base (int): Base of the geometric frequency progression.
        is_neox_style (bool): Split-halves (Neox) vs interleaved (GPT-J) pairing.
        dtype (jnp.dtype): Forwarded to ``apply_basic_rope`` as its ``dtype`` argument.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: jnp.dtype,
    ):
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype

    @jax.named_scope("easydel-rope-embedding")
    def __call__(
        self,
        positions: jnp.ndarray,
        query: jnp.ndarray,
        key: jnp.ndarray,
        offsets: jnp.ndarray | None = None,
        frequencies: jnp.ndarray | None = None,
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """
        Rotate ``query`` and ``key`` by the angles for ``positions`` (+ ``offsets``).

        If ``frequencies`` is omitted, the standard cache is built on the fly.
        Wrapper objects exposing ``.value`` are unwrapped before use.
        """
        with jax.ensure_compile_time_eval():
            table = frequencies
            if table is None:
                table = compute_basic_frequencies(
                    base=self.base,
                    rotary_dim=self.rotary_dim,
                    max_position_embeddings=self.max_position_embeddings,
                )
            table = getattr(table, "value", table)
            return apply_basic_rope(
                query=query,
                key=key,
                positions=positions,
                frequencies=table,
                rotary_dim=self.rotary_dim,
                is_neox_style=self.is_neox_style,
                offsets=offsets,
                dtype=self.dtype,
            )
@rope_wraper("mrope")
class MultiModalRotaryEmbedding(RotaryEmbedding):
    """Multi-dimensional RoPE (MRoPE) with an interleaved T/H/W frequency layout.

    Positions carry three coordinates (temporal, height, width). For each of the
    three coordinates, rotary angles are computed for every frequency slot. The
    three angle sets are then merged: slots ``1, 4, 7, ...`` (up to
    ``3 * mrope_section[1]``) take the height angles, slots ``2, 5, 8, ...`` (up
    to ``3 * mrope_section[2]``) take the width angles, and every other slot
    keeps the temporal angle. The merged half-width cos/sin are duplicated to
    full width and applied with Neox-style rotation.

    Attributes:
        mrope_section: ``(T, H, W)`` section sizes. Only the H and W entries
            bound the interleaved ranges. Defaults to ``(24, 20, 20)``.
        attention_scaling: Multiplier applied to cos and sin. Defaults to 1.0.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: jnp.dtype,
        mrope_section: tuple[int, int, int] | None = None,
        attention_scaling: float = 1.0,
    ):
        super().__init__(
            head_size=head_size,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position_embeddings,
            base=base,
            is_neox_style=is_neox_style,
            dtype=dtype,
        )
        self.mrope_section = (24, 20, 20) if mrope_section is None else mrope_section
        self.attention_scaling = attention_scaling

    def _apply_interleaved_mrope(self, freqs: jax.Array) -> jax.Array:
        """Merge ``freqs[0..2]`` (T, H, W) into one array using the stride-3 slot layout."""
        merged = freqs[0]
        # Axis 1 (H) owns slots 1, 4, 7, ...; axis 2 (W) owns slots 2, 5, 8, ...
        for axis in (1, 2):
            lanes = slice(axis, self.mrope_section[axis] * 3, 3)
            merged = merged.at[..., lanes].set(freqs[axis, ..., lanes])
        return merged

    @jax.named_scope("easydel-mrope")
    def __call__(
        self,
        positions: jax.Array,
        query: jax.Array,
        key: jax.Array,
        offsets: jax.Array | None = None,
        frequencies: jax.Array | None = None,
    ) -> tuple[jax.Array, jax.Array]:
        """Apply interleaved T/H/W MRoPE to query/key.

        Args:
            positions: Position ids of shape ``(batch, seq)`` or ``(3, batch, seq)``.
                A 2-D input is broadcast so T, H and W share the same positions.
            query: Query tensor; the inserted heads axis sits at index 2, so a
                ``(batch, seq, heads, dim)`` layout is expected.
            key: Key tensor with the same layout as ``query``.
            offsets: Optional values added to ``positions``.
            frequencies: Optional ``[max_pos, rotary_dim]`` cache laid out as
                ``[cos, sin]``. Wrappers exposing ``.value`` are unwrapped. When
                omitted, angles are computed from ``compute_basic_inv_frequencies``.

        Returns:
            Tuple of (rotated_query, rotated_key) cast to ``self.dtype``.

        Raises:
            ValueError: If ``positions`` is neither 2-D nor of shape ``(3, batch, seq)``.
        """
        if positions.ndim == 2:
            positions = jnp.broadcast_to(positions[jnp.newaxis, ...], (3, *positions.shape))
        elif positions.ndim != 3 or positions.shape[0] != 3:
            raise ValueError(f"Position IDs must have shape (batch, seq) or (3, batch, seq); got {positions.shape}.")

        if offsets is not None:
            positions = positions + offsets

        if frequencies is None:
            inv_freq = compute_basic_inv_frequencies(self.base, self.rotary_dim)
            angles = positions[..., jnp.newaxis].astype(jnp.float32) * inv_freq[jnp.newaxis, jnp.newaxis, jnp.newaxis, :]
            angles = self._apply_interleaved_mrope(angles)  # (batch, seq, dim/2)
            cos_half = jnp.cos(angles)
            sin_half = jnp.sin(angles)
        else:
            cache = jnp.asarray(getattr(frequencies, "value", frequencies))
            # Gathering with the (3, batch, seq) index array yields (3, batch, seq, rotary_dim).
            gathered = cache[positions]
            cos_all, sin_all = jnp.split(gathered, 2, axis=-1)
            cos_half = self._apply_interleaved_mrope(cos_all)
            sin_half = self._apply_interleaved_mrope(sin_all)

        # Duplicate half-width tables to full width, then apply the magnitude scaling.
        cos = jnp.concatenate([cos_half, cos_half], axis=-1) * self.attention_scaling
        sin = jnp.concatenate([sin_half, sin_half], axis=-1) * self.attention_scaling

        cos = cos[:, :, jnp.newaxis, :]
        sin = sin[:, :, jnp.newaxis, :]
        rotated_query = (query * cos) + (_rotate_neox(query) * sin)
        rotated_key = (key * cos) + (_rotate_neox(key) * sin)
        return rotated_query.astype(self.dtype), rotated_key.astype(self.dtype)
@rope_wraper("linear")
class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """
    Rotary embedding whose frequency cache is built with linear position scaling.

    When no cache is supplied at call time, one is produced by
    ``compute_linear_frequencies`` from ``scaling_factors`` and then applied
    through ``apply_basic_rope``.

    Attributes:
        scaling_factors (list[float] | float): Scaling factor(s) forwarded to
            ``compute_linear_frequencies``.

    Every other attribute is inherited from ``RotaryEmbedding``.
    """

    def __init__(
        self,
        scaling_factors: list[float] | float,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: jnp.dtype,
    ):
        super().__init__(
            head_size=head_size,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position_embeddings,
            base=base,
            is_neox_style=is_neox_style,
            dtype=dtype,
        )
        self.scaling_factors = scaling_factors

    @jax.named_scope("easydel-rope-linear-scaling")
    def __call__(
        self,
        positions: jnp.ndarray,
        query: jnp.ndarray,
        key: jnp.ndarray,
        offsets: jnp.ndarray | None = None,
        frequencies: jnp.ndarray | None = None,
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """
        Rotate ``query`` and ``key`` with the linearly scaled frequency cache.

        Args:
            positions: Position indices forwarded to ``apply_basic_rope``.
            query: Query tensor to rotate.
            key: Key tensor to rotate.
            offsets: Optional position offsets forwarded to ``apply_basic_rope``.
            frequencies: Optional precomputed cache; objects exposing ``.value``
                are unwrapped. Computed on the fly when ``None``.

        Returns:
            The ``(query, key)`` pair produced by ``apply_basic_rope``.
        """
        with jax.ensure_compile_time_eval():
            cache = (
                compute_linear_frequencies(
                    base=self.base,
                    rotary_dim=self.rotary_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factors=self.scaling_factors,
                )
                if frequencies is None
                else frequencies
            )
            # Wrapped caches expose the raw array through ``.value``.
            cache = getattr(cache, "value", cache)
            return apply_basic_rope(
                query=query,
                key=key,
                positions=positions,
                frequencies=cache,
                rotary_dim=self.rotary_dim,
                is_neox_style=self.is_neox_style,
                offsets=offsets,
                dtype=self.dtype,
            )
@rope_wraper("dynamic")
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """
    Rotary embedding using Dynamic NTK scaling.

    The frequency cache is produced by ``compute_dynamic_frequencies`` from
    ``scaling_factor`` (unless one is passed in at call time) and applied via
    ``apply_basic_rope``.

    Attributes:
        scaling_factor (list[float] | float): Factor forwarded to
            ``compute_dynamic_frequencies``.

    Every other attribute is inherited from ``RotaryEmbedding``.
    """

    def __init__(
        self,
        scaling_factor: list[float] | float,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: jnp.dtype,
    ):
        super().__init__(
            head_size=head_size,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position_embeddings,
            base=base,
            is_neox_style=is_neox_style,
            dtype=dtype,
        )
        self.scaling_factor = scaling_factor

    @jax.named_scope("easydel-rope-dynamic-ntk-scaling")
    def __call__(
        self,
        positions: jnp.ndarray,
        query: jnp.ndarray,
        key: jnp.ndarray,
        offsets: jnp.ndarray | None = None,
        frequencies: jnp.ndarray | None = None,
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """
        Rotate ``query`` and ``key`` with the Dynamic-NTK frequency cache.

        Args:
            positions: Position indices forwarded to ``apply_basic_rope``.
            query: Query tensor to rotate.
            key: Key tensor to rotate.
            offsets: Optional position offsets forwarded to ``apply_basic_rope``.
            frequencies: Optional precomputed cache; objects exposing ``.value``
                are unwrapped. Computed on the fly when ``None``.

        Returns:
            The ``(query, key)`` pair produced by ``apply_basic_rope``.
        """
        with jax.ensure_compile_time_eval():
            cache = frequencies
            if cache is None:
                cache = compute_dynamic_frequencies(
                    base=self.base,
                    rotary_dim=self.rotary_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=self.scaling_factor,
                )
            cache = getattr(cache, "value", cache)
            return apply_basic_rope(
                query=query,
                key=key,
                positions=positions,
                frequencies=cache,
                rotary_dim=self.rotary_dim,
                is_neox_style=self.is_neox_style,
                offsets=offsets,
                dtype=self.dtype,
            )
@rope_wraper("yarn")
class YaRNScalingRotaryEmbedding(RotaryEmbedding):
    """
    Rotary embedding using YaRN (Yet another RoPE extensioN) scaling.

    All YaRN hyper-parameters are stored on the instance and forwarded to
    ``compute_yarn_frequencies`` when no cache is supplied at call time.

    Attributes:
        scaling_factor (float | int): Context-length scaling factor.
        extrapolation_factor (float): YaRN extrapolation factor.
        attn_factor (float): YaRN attention factor.
        beta_fast (int): YaRN high-frequency correction parameter.
        beta_slow (int): YaRN low-frequency correction parameter.

    ``max_position_embeddings`` is passed to ``compute_yarn_frequencies`` as
    is; ``get_rope`` supplies the original (pre-scaling) context length for it.
    Every other attribute is inherited from ``RotaryEmbedding``.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: jnp.dtype,
        scaling_factor: float | int = 1.0,
        extrapolation_factor: float = 1.0,
        attn_factor: float = 1.0,
        beta_fast: int = 32,
        beta_slow: int = 1,
    ):
        super().__init__(
            head_size=head_size,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position_embeddings,
            base=base,
            is_neox_style=is_neox_style,
            dtype=dtype,
        )
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow

    @jax.named_scope("easydel-rope-yarn-scaling")
    def __call__(
        self,
        positions: jnp.ndarray,
        query: jnp.ndarray,
        key: jnp.ndarray,
        offsets: jnp.ndarray | None = None,
        frequencies: jnp.ndarray | None = None,
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """
        Rotate ``query`` and ``key`` with the YaRN frequency cache.

        Args:
            positions: Position indices forwarded to ``apply_basic_rope``.
            query: Query tensor to rotate.
            key: Key tensor to rotate.
            offsets: Optional position offsets forwarded to ``apply_basic_rope``.
            frequencies: Optional precomputed cache; objects exposing ``.value``
                are unwrapped. Computed on the fly when ``None``.

        Returns:
            The ``(query, key)`` pair produced by ``apply_basic_rope``.
        """
        with jax.ensure_compile_time_eval():
            if frequencies is not None:
                cache = frequencies
            else:
                yarn_params = dict(
                    scaling_factor=self.scaling_factor,
                    beta_fast=self.beta_fast,
                    beta_slow=self.beta_slow,
                    extrapolation_factor=self.extrapolation_factor,
                    attn_factor=self.attn_factor,
                )
                cache = compute_yarn_frequencies(
                    base=self.base,
                    rotary_dim=self.rotary_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    **yarn_params,
                )
            cache = getattr(cache, "value", cache)
            return apply_basic_rope(
                query=query,
                key=key,
                positions=positions,
                frequencies=cache,
                rotary_dim=self.rotary_dim,
                is_neox_style=self.is_neox_style,
                offsets=offsets,
                dtype=self.dtype,
            )
@rope_wraper("longrope")
class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
    """
    Rotary embedding implementing Phi-3 LongRoPE scaling.

    The frequency cache comes from ``compute_phi3_frequencies``, which receives
    both ``short_factor`` and ``long_factor`` together with the target and
    original context lengths; rotation is delegated to ``apply_phi3_rope``.

    Attributes:
        head_size (int): Dimension of each attention head.
        rotary_dim (int): Rotated dimension. Expected to equal ``head_size``
            (not validated in this class).
        max_position_embeddings (int): Target maximum sequence length.
        original_max_position_embeddings (int): Pre-scaling maximum length.
        base (int): Base for the frequency computation.
        is_neox_style (bool): Stored on the instance; this class does not
            forward it to ``apply_phi3_rope``.
        dtype (jnp.dtype): Output dtype forwarded to ``apply_phi3_rope``.
        short_factor (list[float]): Per-dimension factors for short contexts.
        long_factor (list[float]): Per-dimension factors for long contexts.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        original_max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: jnp.dtype,
        short_factor: list[float],
        long_factor: list[float],
    ):
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.original_max_position_embeddings = original_max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype
        self.short_factor = short_factor
        self.long_factor = long_factor

    @jax.named_scope("easydel-rope-phi3-long")
    def __call__(
        self,
        positions: jnp.ndarray,
        query: jnp.ndarray,
        key: jnp.ndarray,
        offsets: jnp.ndarray | None = None,
        frequencies: jnp.ndarray | None = None,
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """
        Rotate ``query`` and ``key`` with the Phi-3 LongRoPE cache.

        Args:
            positions: Position indices forwarded to ``apply_phi3_rope``.
            query: Query tensor to rotate.
            key: Key tensor to rotate.
            offsets: Optional position offsets forwarded to ``apply_phi3_rope``.
            frequencies: Optional precomputed cache; objects exposing ``.value``
                are unwrapped. Computed on the fly when ``None``.

        Returns:
            The ``(query, key)`` pair produced by ``apply_phi3_rope``.
        """
        with jax.ensure_compile_time_eval():
            cache = frequencies
            if cache is None:
                cache = compute_phi3_frequencies(
                    base=self.base,
                    head_size=self.head_size,
                    rotary_dim=self.rotary_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    original_max_position_embeddings=self.original_max_position_embeddings,
                    short_factor=self.short_factor,
                    long_factor=self.long_factor,
                )
            cache = getattr(cache, "value", cache)
            return apply_phi3_rope(
                query=query,
                key=key,
                positions=positions,
                frequencies=cache,
                offsets=offsets,
                dtype=self.dtype,
            )
@rope_wraper("llama3")
class Llama3RotaryEmbedding(RotaryEmbedding):
    """
    Rotary embedding implementing the Llama-3 frequency scaling scheme.

    When no cache is supplied, ``compute_llama3_frequencies`` is called with the
    stored factors; note that ``orig_max_position`` (not the inherited
    ``max_position_embeddings``) is what it receives as
    ``max_position_embeddings``.

    Attributes:
        scaling_factor (float): Overall scaling factor.
        low_freq_factor (float): Low-frequency wavelength factor.
        high_freq_factor (float): High-frequency wavelength factor.
        orig_max_position (int): Original (pre-scaling) maximum sequence length.

    Every other attribute is inherited from ``RotaryEmbedding``.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: jnp.dtype,
        scaling_factor: float,
        low_freq_factor: float,
        high_freq_factor: float,
        orig_max_position: int,
    ):
        super().__init__(
            head_size=head_size,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position_embeddings,
            base=base,
            is_neox_style=is_neox_style,
            dtype=dtype,
        )
        self.scaling_factor = scaling_factor
        self.low_freq_factor = low_freq_factor
        self.high_freq_factor = high_freq_factor
        self.orig_max_position = orig_max_position

    @jax.named_scope("easydel-rope-llama3")
    def __call__(
        self,
        positions: jnp.ndarray,
        query: jnp.ndarray,
        key: jnp.ndarray,
        offsets: jnp.ndarray | None = None,
        frequencies: jnp.ndarray | None = None,
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """
        Rotate ``query`` and ``key`` with the Llama-3 frequency cache.

        Args:
            positions: Position indices forwarded to ``apply_basic_rope``.
            query: Query tensor to rotate.
            key: Key tensor to rotate.
            offsets: Optional position offsets forwarded to ``apply_basic_rope``.
            frequencies: Optional precomputed cache; objects exposing ``.value``
                are unwrapped. Computed on the fly when ``None``.

        Returns:
            The ``(query, key)`` pair produced by ``apply_basic_rope``.
        """
        with jax.ensure_compile_time_eval():
            if frequencies is None:
                cache = compute_llama3_frequencies(
                    base=self.base,
                    rotary_dim=self.rotary_dim,
                    low_freq_factor=self.low_freq_factor,
                    high_freq_factor=self.high_freq_factor,
                    scaling_factor=self.scaling_factor,
                    max_position_embeddings=self.orig_max_position,
                )
            else:
                cache = frequencies
            cache = getattr(cache, "value", cache)
            return apply_basic_rope(
                query=query,
                key=key,
                positions=positions,
                frequencies=cache,
                rotary_dim=self.rotary_dim,
                is_neox_style=self.is_neox_style,
                offsets=offsets,
                dtype=self.dtype,
            )
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """
    Return the YaRN magnitude-scaling ("mscale") coefficient.

    The coefficient is ``0.1 * mscale * ln(scale) + 1`` for ``scale > 1``; any
    scale of at most 1 yields exactly ``1.0`` (no magnitude correction).

    Args:
        scale (float, optional): Context scaling factor. Defaults to 1.
        mscale (float, optional): Multiplier on the logarithmic term. Defaults to 1.

    Returns:
        float: The magnitude-scaling coefficient.
    """
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
@rope_wraper("deepseek_yarn")
class DeepseekScalingRotaryEmbedding(nn.Module):
    """
    RotaryEmbedding implementing a YaRN-like scaling method, as used by Deepseek models.

    Uses YaRN parameters (`beta_fast`, `beta_slow`, `extrapolation_factor`) plus
    the m-scale parameters (`mscale`, `mscale_all_dim`), all forwarded to
    `compute_deepseek_frequencies`. This class performs the rotation itself
    rather than delegating to `apply_basic_rope`, and it returns the rotated
    tensors without casting them to `dtype`.

    Attributes:
        head_size (int): Dimension of each attention head.
        rotary_dim (int): Dimension subjected to rotary embedding; the remaining
            `head_size - rotary_dim` features pass through unchanged.
        max_position_embeddings (int): Original maximum sequence length before scaling.
        base (int): Base for frequency calculation.
        is_neox_style (bool): Use Neox rotation if True, GPT-J otherwise.
        dtype (jnp.dtype): Stored data type (not applied to the outputs here).
        scaling_factor (float): Primary scaling factor.
        extrapolation_factor (float): YaRN extrapolation factor.
        attn_factor (float): Attention scaling factor.
        beta_fast (int): YaRN parameter.
        beta_slow (int): YaRN parameter.
        mscale (float): Parameter for m-scale calculation.
        mscale_all_dim (float): Parameter for m-scale calculation.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: jnp.dtype,
        scaling_factor: float,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
        mscale: float = 1,
        mscale_all_dim: float = 0,
    ):
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = mscale
        self.mscale_all_dim = mscale_all_dim

    @jax.named_scope("easydel-rope-deepseek")
    def __call__(
        self,
        positions: jnp.ndarray,
        query: jnp.ndarray,
        key: jnp.ndarray,
        offsets: jnp.ndarray | None = None,
        frequencies: jnp.ndarray | None = None,
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """
        Apply Deepseek-YaRN rotary embedding to `query` and `key`.

        Args:
            positions: Integer position indices used to gather rows of the
                frequency cache. The reshape below assumes they are
                (batch, seq) with batch matching `query.shape[0]` — TODO confirm
                for other layouts.
            query: Query tensor; its last axis holds the head features.
            key: Key tensor; its last axis holds the head features.
            offsets: Optional offsets added to `positions` before the cache lookup.
            frequencies: Optional precomputed cache whose last axis is the
                concatenation `[cos, sin]`. Objects exposing the raw array via
                `.value` are unwrapped. Computed on the fly when `None`.

        Returns:
            Tuple `(query, key)` with the first `rotary_dim` features rotated
            and any remaining features passed through unchanged.
        """
        if frequencies is None:
            frequencies = compute_deepseek_frequencies(
                self.base,
                self.rotary_dim,
                self.scaling_factor,
                self.extrapolation_factor,
                self.beta_fast,
                self.beta_slow,
                self.max_position_embeddings,
                self.mscale,
                self.mscale_all_dim,
                self.attn_factor,
            )
        # Same unwrapping as the other RoPE modules in this file.
        if hasattr(frequencies, "value"):
            frequencies = frequencies.value

        # Offsets shift the positions that index the cache, so they must be
        # applied *before* the lookup to have any effect.
        if offsets is not None:
            positions = positions + offsets
        cos, sin = jnp.split(frequencies[positions], 2, -1)

        query_rot = query[..., : self.rotary_dim]
        key_rot = key[..., : self.rotary_dim]
        if self.rotary_dim < self.head_size:
            query_pass = query[..., self.rotary_dim :]
            key_pass = key[..., self.rotary_dim :]

        # Insert a singleton head axis so cos/sin broadcast over heads.
        target_sc_shape = (query.shape[0], -1, 1, self.rotary_dim)
        if self.is_neox_style:
            # Neox layout: the half-dim cos/sin are tiled -> [c0..cN, c0..cN].
            cos = jnp.concatenate([cos, cos], axis=-1).reshape(target_sc_shape)
            sin = jnp.concatenate([sin, sin], axis=-1).reshape(target_sc_shape)
        else:
            # GPT-J layout: each entry is duplicated in place -> [c0, c0, c1, c1, ...].
            cos = jnp.repeat(cos, 2, axis=-1).reshape(target_sc_shape)
            sin = jnp.repeat(sin, 2, axis=-1).reshape(target_sc_shape)

        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
        key_rot = key_rot * cos + rotate_fn(key_rot) * sin

        if self.rotary_dim < self.head_size:
            query = jnp.concatenate((query_rot, query_pass), axis=-1)
            key = jnp.concatenate((key_rot, key_pass), axis=-1)
        else:
            query = query_rot
            key = key_rot
        return query, key
def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: dict[str, tp.Any] | None = None,
    dtype: jnp.dtype | None = None,
    partial_rotary_factor: float = 1.0,
) -> RotaryEmbedding:
    """
    Factory function that builds a rotary-embedding module from configuration.

    Selects the appropriate RoPE class (standard, linear, dynamic, YaRN, Llama3,
    Phi3, Deepseek, multimodal) based on the `rope_scaling` dictionary.

    Args:
        head_size (int): Dimension of each attention head.
        rotary_dim (int): Base dimension for rotary embedding (before partial factor).
        max_position (int): Maximum sequence length the model should support (target length).
        base (int): Base value for frequency calculation.
        is_neox_style (bool, optional): Use Neox rotation style. Defaults to True.
        rope_scaling (tp.Optional[tp.Dict[str, tp.Any]], optional): Dictionary
            specifying the type and parameters of RoPE scaling. Its `rope_type`
            key selects the class; if both `mrope_interleaved` and
            `mrope_section` are present, the multimodal variant is used
            regardless of `rope_type`. If None, standard RoPE is used.
            Defaults to None.
        dtype (tp.Optional[jnp.dtype], optional): Data type for embeddings.
            Defaults to jnp.float32.
        partial_rotary_factor (float, optional): Factor (< 1.0) that shrinks the
            rotary dimension, e.g. 0.5 applies RoPE to half the dimensions.
            Defaults to 1.0.

    Returns:
        The configured rotary-embedding module. Most variants subclass
        `RotaryEmbedding`; the Phi3 and Deepseek variants are plain `nn.Module`s.

    Raises:
        ValueError: If `rope_scaling` specifies an unknown `rope_type`, or a
            YaRN config provides neither `factor` nor `scaling_factor`.
        KeyError: If a key required by the selected type is missing.
    """
    if dtype is None:
        dtype = jnp.float32  # Default JAX dtype

    if partial_rotary_factor < 1.0:
        rotary_dim = int(rotary_dim * partial_rotary_factor)

    if rope_scaling is None:
        return RotaryEmbedding(
            head_size=head_size,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position,
            base=base,
            is_neox_style=is_neox_style,
            dtype=dtype,
        )

    scaling_type = rope_scaling["rope_type"]
    # An interleaved multimodal section spec overrides the declared rope_type.
    if "mrope_interleaved" in rope_scaling and "mrope_section" in rope_scaling:
        scaling_type = "mrope"

    if scaling_type == "llama3":
        return Llama3RotaryEmbedding(
            head_size=head_size,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position,
            base=base,
            is_neox_style=is_neox_style,
            dtype=dtype,
            scaling_factor=rope_scaling["factor"],
            low_freq_factor=rope_scaling["low_freq_factor"],
            high_freq_factor=rope_scaling["high_freq_factor"],
            orig_max_position=rope_scaling["original_max_position_embeddings"],
        )

    if scaling_type == "default":
        return RotaryEmbedding(
            head_size=head_size,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position,
            base=base,
            is_neox_style=is_neox_style,
            dtype=dtype,
        )

    if scaling_type == "linear":
        return LinearScalingRotaryEmbedding(
            head_size=head_size,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position,
            base=base,
            is_neox_style=is_neox_style,
            scaling_factors=rope_scaling["factor"],
            dtype=dtype,
        )

    if scaling_type == "dynamic":
        return DynamicNTKScalingRotaryEmbedding(
            head_size=head_size,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position,
            base=base,
            is_neox_style=is_neox_style,
            scaling_factor=rope_scaling["factor"],
            dtype=dtype,
        )

    if scaling_type == "yarn":
        scaling_factor = rope_scaling.get("factor", rope_scaling.get("scaling_factor"))
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if scaling_factor is None:
            raise ValueError("YaRN rope_scaling requires a 'factor' (or 'scaling_factor') entry.")
        original_max_position = rope_scaling.get("original_max_position_embeddings", max_position)
        extra_kwargs = {
            k: v
            for k, v in rope_scaling.items()
            if k in ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow")
        }
        # YaRN math is relative to the original (pre-scaling) context length.
        return YaRNScalingRotaryEmbedding(
            head_size=head_size,
            rotary_dim=rotary_dim,
            max_position_embeddings=original_max_position,
            base=base,
            is_neox_style=is_neox_style,
            scaling_factor=scaling_factor,
            dtype=dtype,
            **extra_kwargs,
        )

    if scaling_type == "deepseek_yarn":
        extra_kwargs = {
            k: v
            for k, v in rope_scaling.items()
            if k in ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow", "mscale", "mscale_all_dim")
        }
        return DeepseekScalingRotaryEmbedding(
            head_size=head_size,
            rotary_dim=rotary_dim,
            max_position_embeddings=rope_scaling["original_max_position_embeddings"],
            base=base,
            is_neox_style=is_neox_style,
            scaling_factor=rope_scaling["factor"],
            dtype=dtype,
            **extra_kwargs,
        )

    if scaling_type == "longrope":
        return Phi3LongRoPEScaledRotaryEmbedding(
            head_size=head_size,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position,
            original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
            base=base,
            is_neox_style=is_neox_style,
            dtype=dtype,
            short_factor=rope_scaling["short_factor"],
            long_factor=rope_scaling["long_factor"],
        )

    if scaling_type == "mrope":
        return MultiModalRotaryEmbedding(
            head_size=head_size,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position,
            base=base,
            is_neox_style=is_neox_style,
            dtype=dtype,
            mrope_section=rope_scaling.get("mrope_section"),
        )

    raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
@ejit(
    static_argnames=[
        "head_size",
        "rotary_dim",
        "max_position",
        "base",
        "rope_scaling",
        "partial_rotary_factor",
    ],
)
def get_frequencies(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    rope_scaling: dict[str, tp.Any] | None = None,
    partial_rotary_factor: float = 1.0,
) -> jax.Array:
    """
    Computes and returns the RoPE frequency cache based on configuration.

    Selects the appropriate frequency computation function (basic, linear,
    dynamic, YaRN, Llama3, Phi3, Deepseek) based on the `rope_scaling`
    dictionary. The function is wrapped with `ejit`, with every argument except
    nothing traced — all parameters are declared static, so `rope_scaling` must
    be hashable.

    Args:
        head_size (int): Dimension of each attention head (used by Phi3 scaling).
        rotary_dim (int): Base dimension for rotary embedding (before partial factor).
        max_position (int): Maximum sequence length for which to compute
            frequencies. Depending on the scaling type this is the target length
            or the fallback for a missing `original_max_position_embeddings`.
        base (int): Base value for frequency calculation.
        rope_scaling (tp.Optional[tp.Dict[str, tp.Any]], optional): Dictionary
            specifying the type and parameters of RoPE scaling. If both
            `mrope_interleaved` and `mrope_section` are present, the basic cache
            is used regardless of `rope_type`. Defaults to None (basic cache).
        partial_rotary_factor (float, optional): Factor (< 1.0) that shrinks the
            rotary dimension. Defaults to 1.0.

    Returns:
        jax.Array: The computed frequency cache; its layout is determined by the
        selected `compute_*_frequencies` helper.

    Raises:
        ValueError: If `rope_scaling` specifies an unknown `rope_type`, or a
            YaRN config provides neither `factor` nor `scaling_factor`.
        KeyError: If a key required by the selected type is missing.
    """
    if partial_rotary_factor < 1.0:
        rotary_dim = int(rotary_dim * partial_rotary_factor)

    if rope_scaling is None:
        return compute_basic_frequencies(
            base=base,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position,
        )

    scaling_type = rope_scaling["rope_type"]
    # An interleaved multimodal section spec overrides the declared rope_type.
    if "mrope_interleaved" in rope_scaling and "mrope_section" in rope_scaling:
        scaling_type = "mrope"

    # Keys forwarded from rope_scaling to the YaRN-family helpers. Note that
    # the HF-style "attention_factor" key is not among them, so only
    # "attn_factor" is honoured (NOTE(review): confirm this is intended).
    yarn_keys = ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow", "mscale", "mscale_all_dim")

    if scaling_type == "llama3":
        return compute_llama3_frequencies(
            base=base,
            rotary_dim=rotary_dim,
            low_freq_factor=rope_scaling["low_freq_factor"],
            high_freq_factor=rope_scaling["high_freq_factor"],
            scaling_factor=rope_scaling["factor"],
            max_position_embeddings=rope_scaling["original_max_position_embeddings"],
        )

    if scaling_type in ("default", "mrope"):
        # MRoPE uses the basic cache; interleaving is handled inside the MRoPE class.
        return compute_basic_frequencies(
            base=base,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position,
        )

    if scaling_type == "linear":
        return compute_linear_frequencies(
            base=base,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position,
            scaling_factors=rope_scaling["factor"],
        )

    if scaling_type == "dynamic":
        return compute_dynamic_frequencies(
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position,
            base=base,
            scaling_factor=rope_scaling["factor"],
        )

    if scaling_type == "yarn":
        # Accept the "scaling_factor" alias, matching get_rope.
        scaling_factor = rope_scaling.get("factor", rope_scaling.get("scaling_factor"))
        if scaling_factor is None:
            raise ValueError("YaRN rope_scaling requires a 'factor' (or 'scaling_factor') entry.")
        # Fallback to max_position for configs without an original length (e.g. gpt_oss).
        original_max_position = rope_scaling.get("original_max_position_embeddings", max_position)
        extra_kwargs = {k: v for k, v in rope_scaling.items() if k in yarn_keys}
        # DeepSeek-style YaRN is recognised by the presence of both m-scale parameters.
        if "mscale" in extra_kwargs and "mscale_all_dim" in extra_kwargs:
            return compute_deepseek_frequencies(
                base,
                rotary_dim,
                scaling_factor,
                extra_kwargs.get("extrapolation_factor", 1.0),
                extra_kwargs.get("beta_fast", 32),
                extra_kwargs.get("beta_slow", 1),
                original_max_position,
                extra_kwargs["mscale"],
                extra_kwargs["mscale_all_dim"],
                extra_kwargs.get("attn_factor", 1),
            )
        return compute_yarn_frequencies(
            base=base,
            rotary_dim=rotary_dim,
            beta_fast=extra_kwargs.get("beta_fast", 32),
            beta_slow=extra_kwargs.get("beta_slow", 1),
            max_position_embeddings=original_max_position,
            scaling_factor=scaling_factor,
            extrapolation_factor=extra_kwargs.get("extrapolation_factor", 1.0),
            attn_factor=extra_kwargs.get("attn_factor", 1),
        )

    if scaling_type == "deepseek_yarn":
        extra_kwargs = {k: v for k, v in rope_scaling.items() if k in yarn_keys}
        # m-scale defaults mirror DeepseekScalingRotaryEmbedding (mscale=1, mscale_all_dim=0).
        return compute_deepseek_frequencies(
            base,
            rotary_dim,
            rope_scaling["factor"],
            extra_kwargs.get("extrapolation_factor", 1.0),
            extra_kwargs.get("beta_fast", 32),
            extra_kwargs.get("beta_slow", 1),
            rope_scaling["original_max_position_embeddings"],
            extra_kwargs.get("mscale", 1),
            extra_kwargs.get("mscale_all_dim", 0),
            extra_kwargs.get("attn_factor", 1),
        )

    if scaling_type == "longrope":
        return compute_phi3_frequencies(
            base=base,
            head_size=head_size,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position,
            original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
            short_factor=rope_scaling["short_factor"],
            long_factor=rope_scaling["long_factor"],
        )

    raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
@ejit(
    static_argnames=[
        "head_size",
        "rotary_dim",
        "max_position",
        "base",
        "rope_scaling",
        "partial_rotary_factor",
    ],
)
def get_inv_frequencies(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    rope_scaling: dict[str, tp.Any] | None = None,
    partial_rotary_factor: float = 1.0,
) -> jax.Array:
    """
    Computes and returns just the inverse frequencies for RoPE based on configuration.

    Similar to `get_frequencies` but returns only the inverse frequencies,
    without building the full cos/sin cache. Type dispatch follows
    `get_frequencies`, including the multimodal override: if both
    `mrope_interleaved` and `mrope_section` are present, basic inverse
    frequencies are returned.

    Args:
        head_size (int): Dimension of each attention head (used by the longrope branch).
        rotary_dim (int): Base dimension for rotary embedding (before partial factor).
        max_position (int): Maximum sequence length the model should support.
        base (int): Base value for frequency calculation.
        rope_scaling (tp.Optional[tp.Dict[str, tp.Any]], optional): Dictionary
            specifying the type and parameters of RoPE scaling. Defaults to None
            (basic inverse frequencies).
        partial_rotary_factor (float, optional): Factor (< 1.0) that shrinks the
            rotary dimension. Defaults to 1.0.

    Returns:
        jax.Array: The inverse frequencies; presumably shape (rotary_dim // 2,)
        for the basic helper — TODO confirm for every branch.

    Raises:
        ValueError: If `rope_scaling` specifies an unknown `rope_type`, or a
            YaRN config provides neither `factor` nor `scaling_factor`.
        KeyError: If a key required by the selected type is missing.
    """
    if partial_rotary_factor < 1.0:
        rotary_dim = int(rotary_dim * partial_rotary_factor)

    if rope_scaling is None:
        return compute_basic_inv_frequencies(base=base, rotary_dim=rotary_dim)

    scaling_type = rope_scaling["rope_type"]
    # Same multimodal override as get_rope / get_frequencies.
    if "mrope_interleaved" in rope_scaling and "mrope_section" in rope_scaling:
        scaling_type = "mrope"

    if scaling_type == "llama3":
        return compute_llama3_inv_frequencies(
            base=base,
            rotary_dim=rotary_dim,
            low_freq_factor=rope_scaling["low_freq_factor"],
            high_freq_factor=rope_scaling["high_freq_factor"],
            orig_max_position=rope_scaling["original_max_position_embeddings"],
            scaling_factor=rope_scaling["factor"],
        )

    if scaling_type in ("default", "mrope"):
        # MultiModalRotaryEmbedding itself uses the basic inverse frequencies.
        return compute_basic_inv_frequencies(base=base, rotary_dim=rotary_dim)

    if scaling_type == "linear":
        # NOTE(review): the linear factor is not applied here; presumably the
        # caller scales positions instead — confirm.
        return compute_basic_inv_frequencies(base=base, rotary_dim=rotary_dim)

    if scaling_type == "dynamic":
        scaling_factor = rope_scaling["factor"]
        # NOTE(review): with seq_len == max_position the bracket reduces to ~1,
        # so the base is effectively unchanged (the init-time behaviour of
        # dynamic NTK). This may differ from compute_dynamic_frequencies.
        adjusted_base = base * ((scaling_factor * max_position / max_position) - (scaling_factor - 1)) ** (
            rotary_dim / (rotary_dim - 2)
        )
        return compute_basic_inv_frequencies(base=adjusted_base, rotary_dim=rotary_dim)

    if scaling_type == "yarn":
        # Accept the "scaling_factor" alias and the max_position fallback, as in
        # get_rope / get_frequencies.
        scaling_factor = rope_scaling.get("factor", rope_scaling.get("scaling_factor"))
        if scaling_factor is None:
            raise ValueError("YaRN rope_scaling requires a 'factor' (or 'scaling_factor') entry.")
        original_max_position = rope_scaling.get("original_max_position_embeddings", max_position)
        return compute_yarn_inv_frequencies(
            base=base,
            rotary_dim=rotary_dim,
            beta_fast=rope_scaling.get("beta_fast", 32),
            beta_slow=rope_scaling.get("beta_slow", 1),
            max_position_embeddings=original_max_position,
            scaling_factor=scaling_factor,
            extrapolation_factor=rope_scaling.get("extrapolation_factor", 1.0),
        )

    if scaling_type == "deepseek_yarn":
        scaling_factor = rope_scaling["factor"]
        original_max_position = rope_scaling["original_max_position_embeddings"]
        extrapolation_factor = rope_scaling.get("extrapolation_factor", 1.0)
        beta_fast = rope_scaling.get("beta_fast", 32)
        beta_slow = rope_scaling.get("beta_slow", 1)

        pos_freqs = base ** (jnp.arange(0, rotary_dim, 2, dtype=jnp.float32) / rotary_dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = _yarn_find_correction_range(beta_fast, beta_slow, rotary_dim, base, original_max_position)
        # Blend weight: 1 -> extrapolated frequency, 0 -> interpolated frequency.
        inv_freq_mask = (
            1 - _yarn_linear_ramp_mask(low, high, rotary_dim // 2, dtype=jnp.float32)
        ) * extrapolation_factor
        return inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask

    if scaling_type == "longrope":
        original_max_position = rope_scaling["original_max_position_embeddings"]
        # Long factors are used only when the target length exceeds the original one.
        factors = rope_scaling["long_factor"] if max_position > original_max_position else rope_scaling["short_factor"]
        ext_factors = jnp.array(factors, dtype=jnp.float32)
        inv_freq_shape = jnp.arange(0, head_size, 2, dtype=jnp.int32).astype(jnp.float32) / head_size
        return 1.0 / (ext_factors * (base**inv_freq_shape))

    raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
# Example usage: build a YaRN rotary module and its frequency cache.
if __name__ == "__main__":
    head_size, rotary_dim, max_position, base = 64, 64, 2048, 10000
    is_neox_style = True
    dtype = jnp.float32
    rope_scaling = dict(
        rope_type="yarn",
        factor=2.0,
        original_max_position_embeddings=1024,
        extrapolation_factor=1.0,
        attn_factor=1.0,
        beta_fast=32,
        beta_slow=1,
    )
    rope = get_rope(
        head_size,
        rotary_dim,
        max_position,
        base,
        is_neox_style=is_neox_style,
        rope_scaling=rope_scaling,
        dtype=dtype,
    )
    # NOTE(review): get_frequencies marks rope_scaling as a static jit argument;
    # a plain dict may need a hashable wrapper (see RopeConfig.to_dict) — confirm.
    freq = get_frequencies(head_size, rotary_dim, max_position, base, rope_scaling)
@chex.dataclass
class RopeConfig:
    """
    Configuration class for RoPE (Rotary Position Embedding) parameters.

    Stores the RoPE type and its scaling parameters so they can be passed around
    as one object and converted to the dictionary form consumed by `get_rope` /
    `get_frequencies`.

    Attributes:
        rope_type (str): The type of RoPE scaling to use (e.g., "default",
            "linear", "yarn", "llama3"). Defaults to "default".
        factor (tp.Optional[float]): General scaling factor used by several
            types (linear, dynamic, yarn, llama3, deepseek_yarn).
        low_freq_factor (tp.Optional[float]): Llama3 low-frequency factor.
        high_freq_factor (tp.Optional[float]): Llama3 high-frequency factor.
        original_max_position_embeddings (tp.Optional[int]): Original context
            window size, required by some scaling methods (yarn, llama3, phi3, deepseek).
        long_factor (tp.Optional[list[float]]): Phi3 LongRoPE factors used for
            lengths above the original.
        short_factor (tp.Optional[list[float]]): Phi3 LongRoPE factors used for
            lengths up to the original.
        long_mscale (tp.Optional[float]): Carried through but not read by `get_rope`.
        short_mscale (tp.Optional[float]): Carried through but not read by `get_rope`.
        beta_fast (tp.Optional[int]): YaRN/Deepseek parameter.
        beta_slow (tp.Optional[int]): YaRN/Deepseek parameter.
        mscale (tp.Optional[float]): Deepseek m-scale parameter.
        mscale_all_dim (tp.Optional[float]): Deepseek m-scale parameter.
        mrope_interleaved (tp.Optional[bool]): Together with `mrope_section`,
            selects multimodal RoPE in `get_rope`.
        mrope_section (tp.Optional[list[int]]): Multimodal RoPE section sizes.
        extrapolation_factor (tp.Optional[float]): YaRN/Deepseek parameter.
        attn_factor (tp.Optional[float]): YaRN/Deepseek parameter.
    """

    rope_type: str = "default"
    factor: float | None = None
    low_freq_factor: float | None = None
    high_freq_factor: float | None = None
    original_max_position_embeddings: int | None = None
    long_factor: list[float] | None = None
    short_factor: list[float] | None = None
    long_mscale: float | None = None
    short_mscale: float | None = None
    beta_fast: int | None = None
    beta_slow: int | None = None
    mscale: float | None = None
    mscale_all_dim: float | None = None
    mrope_interleaved: bool | None = None
    mrope_section: list[int] | None = None
    # Appended after the existing fields so their order is unchanged.
    extrapolation_factor: float | None = None
    attn_factor: float | None = None

    @classmethod
    def from_dict(cls, config_dict: dict[str, tp.Any]) -> RopeConfig:
        """
        Create a RopeConfig instance from a dictionary.

        Accepts `type` as an alias for `rope_type` and `scaling_factor` as an
        alias for `factor`. Unknown keys are ignored.

        Args:
            config_dict (tp.Dict[str, tp.Any]): Dictionary containing RoPE configuration.

        Returns:
            RopeConfig: An instance populated from the dictionary.
        """
        return cls(
            rope_type=config_dict.get("rope_type") or config_dict.get("type", "default"),
            factor=config_dict.get("factor") or config_dict.get("scaling_factor"),
            low_freq_factor=config_dict.get("low_freq_factor"),
            high_freq_factor=config_dict.get("high_freq_factor"),
            original_max_position_embeddings=config_dict.get("original_max_position_embeddings"),
            long_factor=config_dict.get("long_factor"),
            short_factor=config_dict.get("short_factor"),
            long_mscale=config_dict.get("long_mscale"),
            short_mscale=config_dict.get("short_mscale"),
            beta_fast=config_dict.get("beta_fast"),
            beta_slow=config_dict.get("beta_slow"),
            mscale=config_dict.get("mscale"),
            mscale_all_dim=config_dict.get("mscale_all_dim"),
            mrope_interleaved=config_dict.get("mrope_interleaved"),
            mrope_section=config_dict.get("mrope_section"),
            # Previously dropped, which lost YaRN settings on a from_dict/to_dict round trip.
            extrapolation_factor=config_dict.get("extrapolation_factor"),
            attn_factor=config_dict.get("attn_factor"),
        )

    def to_dict(self) -> dict[str, tp.Any]:
        """
        Convert the RopeConfig instance to a dictionary.

        Attributes whose value is None are omitted. The result is a `dict`
        subclass whose `__hash__` is `hash_fn`, so it can be passed where a
        hashable value is needed (e.g. a static jit argument).

        Returns:
            tp.Dict[str, tp.Any]: A hashable dictionary containing the non-None
            configuration values.
        """
        from easydel.utils.compiling_utils import hash_fn

        class rope_scaling(dict):
            """A dictionary subclass that is hashable."""

            __hash__ = hash_fn

        scale = rope_scaling({k: v for k, v in self.__dict__.items() if v is not None})
        return scale