# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
import typing as tp
# from functools import partial
from functools import partial
import chex
import jax
import jax.numpy as jnp
from flax import nnx as nn
@jax.named_scope("easydel-rotary-yarn-find-correction-dim")
def _yarn_find_correction_dim(
    num_rotations: int,
    dim: int,
    base: float = 10000,
    max_position_embeddings: int = 2048,
) -> float:
    """Map a rotation count to a (fractional) rotary dimension index.

    Evaluates
    ``dim * log(max_position_embeddings / (num_rotations * 2 * pi)) / (2 * log(base))``.

    Args:
        num_rotations: Rotation count to solve for.
        dim: Rotary dimension.
        base: RoPE base.
        max_position_embeddings: Context length used in the formula.

    Returns:
        The unrounded index. Because ``jnp.log`` is used, the value comes
        back as a JAX scalar rather than a Python float.
    """
    full_turns = num_rotations * 2 * jnp.pi
    numerator = dim * jnp.log(max_position_embeddings / full_turns)
    denominator = 2 * jnp.log(base)
    return numerator / denominator
@jax.named_scope("easydel-rotary-yarn-find-correction-range")
def _yarn_find_correction_range(
    low_rot: int,
    high_rot: int,
    dim: int,
    base: float = 10000,
    max_position_embeddings: int = 2048,
) -> tp.Tuple[int, int]:
    """Compute the clamped dimension range used by the YaRN ramp mask.

    Args:
        low_rot: Rotation count whose correction dim is floored to give the
            lower bound.
        high_rot: Rotation count whose correction dim is ceiled to give the
            upper bound.
        dim: Rotary dimension; the upper bound is capped at ``dim - 1``.
        base: RoPE base forwarded to ``_yarn_find_correction_dim``.
        max_position_embeddings: Context length forwarded to
            ``_yarn_find_correction_dim``.

    Returns:
        ``(low, high)`` with ``low >= 0`` and ``high <= dim - 1``. Both are
        floating-point JAX scalars, not Python ints.
    """
    upper = jnp.ceil(
        _yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
    )
    lower = jnp.floor(
        _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
    )
    last_index = jnp.array(dim - 1, dtype=jnp.float32)
    return jax.lax.max(lower, 0.0), jax.lax.min(upper, last_index)
@jax.named_scope("easydel-rotary-yarn-linear-ramp-mask")
def _yarn_linear_ramp_mask(
    low: float,
    high: float,
    dim: int,
    dtype: jnp.dtype,
) -> jnp.ndarray:
    """Build a length-``dim`` ramp that rises linearly from 0 to 1.

    Entry ``i`` equals ``(i - low) / (high - low)`` clipped to ``[0, 1]``.

    Args:
        low: Index at which the ramp leaves 0.
        high: Index at which the ramp reaches 1.
        dim: Number of entries to produce.
        dtype: dtype of the index vector.

    Returns:
        The clipped ramp.
    """
    # When both bounds coincide, push the upper one up slightly so the
    # division below is well defined. ``lax.cond`` keeps this traceable.
    high = jax.lax.cond(
        low == high,
        lambda bound: bound + 0.001,
        lambda bound: bound,
        high,
    )
    indices = jnp.arange(dim, dtype=dtype)
    unclipped = (indices - low) / (high - low)
    return jnp.clip(unclipped, 0, 1)
@jax.named_scope("easydel-rotary-yarn-get-mscale")
def _yarn_get_mscale(scale: float = 1) -> float:
    """Return the YaRN magnitude scale for a context scaling factor.

    Args:
        scale: Context scaling factor.

    Returns:
        ``1.0`` when ``scale <= 1``, otherwise ``0.1 * log(scale) + 1.0``
        (a JAX scalar, since ``jnp.log`` is used).
    """
    return 1.0 if scale <= 1 else 0.1 * jnp.log(scale) + 1.0
@jax.named_scope("easydel-rotary-rotate-neox")
def _rotate_neox(x: jnp.ndarray) -> jnp.ndarray:
    """Swap the two halves of the last axis, negating the second half.

    For a last axis ``[a, b]`` (two halves) the result is ``[-b, a]``.
    """
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return jnp.concatenate((-back, front), axis=-1)
@jax.named_scope("easydel-rotary-rotate-gptj")
def _rotate_gptj(x: jnp.ndarray) -> jnp.ndarray:
    """Rotate adjacent pairs of the last axis.

    Each pair ``(x[2i], x[2i + 1])`` becomes ``(-x[2i + 1], x[2i])``; the
    output keeps the interleaved layout and the input's shape.
    """
    evens = x[..., 0::2]
    odds = x[..., 1::2]
    paired = jnp.stack((-odds, evens), axis=-1)
    # Merge the trailing (pairs, 2) axes back into one interleaved axis.
    return paired.reshape(*paired.shape[:-2], -1)
@jax.named_scope("easydel-rotary-apply-rotary-emb")
def _apply_rotary_emb(
    x: jnp.ndarray,
    cos: jnp.ndarray,
    sin: jnp.ndarray,
    is_neox_style: bool,
) -> jnp.ndarray:
    """Rotate ``x`` with precomputed cosine / sine tables.

    Args:
        x: Tensor to rotate. Its last axis is split into two halves of
            equal size (contiguous halves for NeoX, even/odd entries for
            GPT-J).
        cos: Cosines. Must have exactly one axis fewer than ``x``: a
            singleton axis is inserted at position 2 so the table broadcasts
            over ``x``'s axis 2 (presumably the heads axis of a
            ``[batch, seq, heads, dim]`` tensor -- confirm against callers).
            The last axis must match half of ``x``'s last axis.
        sin: Sines, same layout as ``cos``.
        is_neox_style: ``True`` for NeoX (half-split) layout, ``False`` for
            GPT-J (interleaved) layout.

    Returns:
        The rotated tensor, same shape as ``x``.
    """
    cos = cos[:, :, None].astype(x.dtype)
    sin = sin[:, :, None].astype(x.dtype)
    assert sin.ndim == x.ndim

    if is_neox_style:
        first, second = jnp.split(x, 2, axis=-1)
    else:
        first, second = x[..., ::2], x[..., 1::2]

    rotated_first = first * cos - second * sin
    rotated_second = second * cos + first * sin

    if not is_neox_style:
        # Re-interleave the two halves to restore the GPT-J layout.
        return jnp.stack((rotated_first, rotated_second), axis=-1).reshape(x.shape)
    return jnp.concatenate((rotated_first, rotated_second), axis=-1)
# Registry filled by ``rope_wraper``: rope type name -> snapshot of the
# decorated class' ``__dict__`` taken at decoration time.
AVAILABLE_ROPE_TYPES = {}


def rope_wraper(type):
    """Return a class decorator that registers a rotary-embedding class.

    The decorator:
      * stores a copy of the class' ``__dict__`` in ``AVAILABLE_ROPE_TYPES``
        under ``type`` (taken before the attributes below are attached),
      * makes ``str()`` / ``repr()`` of instances yield the class name,
      * records ``type`` on the class as ``_type``.

    Args:
        type: Registry key for the decorated class.

    Returns:
        The decorator; it returns the same class object it receives.
    """

    def _class_name_str(instance):
        return str(instance.__class__.__name__)

    def _class_name_repr(instance):
        return repr(instance.__class__.__name__)

    def _register(rope: RotaryEmbedding):
        AVAILABLE_ROPE_TYPES[type] = dict(rope.__dict__.items())
        rope.__str__ = _class_name_str
        rope.__repr__ = _class_name_repr
        rope._type = type
        return rope

    return _register
@jax.named_scope("easydel-rotary-compute-basic-inv-frequencies")
def compute_basic_inv_frequencies(base: int, rotary_dim: int):
    """Return the standard RoPE inverse frequencies.

    Args:
        base: RoPE base.
        rotary_dim: Rotary dimension; one frequency is produced per even
            index in ``range(0, rotary_dim, 2)``.

    Returns:
        float32 vector ``1 / base ** (2i / rotary_dim)``.
    """
    exponents = jnp.arange(0, rotary_dim, 2, dtype="f4") / rotary_dim
    return 1.0 / (base**exponents)
@jax.named_scope("easydel-rotary-compute-yarn-inv-frequencies")
def compute_yarn_inv_frequencies(
    base: float,
    rotary_dim: int,
    beta_fast: float,
    beta_slow: float,
    max_position_embeddings: int,
    scaling_factor: float,
    extrapolation_factor: float,
) -> jnp.ndarray:
    """Blend interpolated and extrapolated inverse frequencies (YaRN).

    Args:
        base: RoPE base.
        rotary_dim: Rotary dimension.
        beta_fast: Passed as ``low_rot`` to the correction-range helper.
        beta_slow: Passed as ``high_rot`` to the correction-range helper.
        max_position_embeddings: Context length used for the correction
            range.
        scaling_factor: Divisor applied to the interpolated frequencies.
        extrapolation_factor: Multiplier on the blend mask.

    Returns:
        Vector of ``rotary_dim // 2`` inverse frequencies.
    """
    exponents = jnp.arange(0, rotary_dim, 2, dtype=jnp.float32) / rotary_dim
    pos_freqs = base**exponents
    extrapolated = 1.0 / pos_freqs
    interpolated = 1.0 / (scaling_factor * pos_freqs)

    low, high = _yarn_find_correction_range(
        low_rot=beta_fast,
        high_rot=beta_slow,
        dim=rotary_dim,
        base=base,
        max_position_embeddings=max_position_embeddings,
    )
    ramp = _yarn_linear_ramp_mask(low, high, rotary_dim // 2, dtype=jnp.float32)
    # Weight of the un-scaled (extrapolated) frequencies: 1 below ``low``,
    # falling to 0 at ``high``, then scaled by ``extrapolation_factor``.
    extrapolation_weight = (1 - ramp) * extrapolation_factor

    return (
        interpolated * (1 - extrapolation_weight)
        + extrapolated * extrapolation_weight
    )
@jax.named_scope("easydel-rotary-compute-llama3-inv-frequencies")
def compute_llama3_inv_frequencies(
    base,
    rotary_dim,
    low_freq_factor,
    high_freq_factor,
    orig_max_position,
    scaling_factor,
):
    """Compute Llama-3 style piecewise-scaled inverse frequencies.

    Per frequency, with wavelength ``2 * pi / inv_freq``:
      * wavelength below ``orig_max_position / high_freq_factor``: kept;
      * wavelength above ``orig_max_position / low_freq_factor``: divided
        by ``scaling_factor``;
      * otherwise: linear blend of the two, weighted by ``smooth``.

    Args:
        base: RoPE base.
        rotary_dim: Rotary dimension.
        low_freq_factor: Factor defining the long-wavelength threshold.
        high_freq_factor: Factor defining the short-wavelength threshold.
        orig_max_position: Context length the thresholds are derived from.
        scaling_factor: Divisor for the long-wavelength frequencies.

    Returns:
        Vector of adjusted inverse frequencies.
    """
    inv_freqs = compute_basic_inv_frequencies(base, rotary_dim)
    low_freq_wavelen = orig_max_position / low_freq_factor
    high_freq_wavelen = orig_max_position / high_freq_factor
    wave_len = 2 * jnp.pi / inv_freqs

    if low_freq_factor == high_freq_factor:
        # Equal factors would divide by zero below; use a zero blend weight.
        smooth = 0
    else:
        smooth = (orig_max_position / wave_len - low_freq_factor) / (
            high_freq_factor - low_freq_factor
        )

    fully_scaled = inv_freqs / scaling_factor
    blended = (1 - smooth) * inv_freqs / scaling_factor + smooth * inv_freqs
    long_or_mid = jnp.where(wave_len > low_freq_wavelen, fully_scaled, blended)
    return jnp.where(wave_len < high_freq_wavelen, inv_freqs, long_or_mid)
@jax.named_scope("easydel-rotary-compute-basic-frequencies")
def compute_basic_frequencies(
    base: int,
    rotary_dim: int,
    max_position_embeddings: int,
):
    """Build the unscaled cos/sin lookup table.

    Args:
        base: RoPE base.
        rotary_dim: Rotary dimension.
        max_position_embeddings: Number of positions (rows) in the table.

    Returns:
        Array of shape ``(max_position_embeddings, rotary_dim)`` whose first
        half along the last axis holds cosines and second half sines.
    """
    inv_freq = compute_basic_inv_frequencies(base, rotary_dim)
    positions = jnp.arange(max_position_embeddings, dtype=jnp.float32)
    angles = jnp.einsum("i,j -> ij", positions, inv_freq)
    return jnp.concatenate([jnp.cos(angles), jnp.sin(angles)], axis=-1)
@jax.named_scope("easydel-rotary-compute-linear-frequencies")
def compute_linear_frequencies(
    base: int,
    rotary_dim: int,
    max_position_embeddings: int,
    scaling_factors: tp.List[float],
):
    """Build cos/sin tables for linear position scaling.

    One table is built per scaling factor (positions ``0 ..
    max_position_embeddings * factor`` divided by ``factor``) and the tables
    are stacked along axis 0 in the order the factors are given.

    Args:
        base: RoPE base.
        rotary_dim: Rotary dimension.
        max_position_embeddings: Unscaled context length.
        scaling_factors: A single factor, or a list / tuple of factors.
            Tuples are accepted so that hashable configurations work.

    Returns:
        Array whose last axis holds ``rotary_dim`` values (cosines then
        sines); rows are the concatenation of all per-factor tables.
    """
    # A tuple used to be wrapped as if it were one factor, which then failed
    # inside ``jnp.arange``; treat it like a list.
    if not isinstance(scaling_factors, (list, tuple)):
        scaling_factors = [scaling_factors]

    inv_freq = compute_basic_inv_frequencies(
        base=base,
        rotary_dim=rotary_dim,
    )

    cache_list: tp.List[jnp.ndarray] = []
    for scaling_factor in scaling_factors:
        max_len = max_position_embeddings * scaling_factor
        t = jnp.arange(max_len, dtype=jnp.float32) / scaling_factor
        freqs = jnp.einsum("i,j -> ij", t, inv_freq)
        cache_list.append(
            jnp.concatenate([jnp.cos(freqs), jnp.sin(freqs)], axis=-1)
        )
    return jnp.concatenate(cache_list, axis=0)
@jax.named_scope("easydel-rotary-compute-dynamic-frequencies")
def compute_dynamic_frequencies(
    base: int,
    rotary_dim: int,
    max_position_embeddings: int,
    scaling_factor: float,
):
    """Build the cos/sin table for dynamic NTK scaling.

    The base is enlarged by
    ``(scaling_factor * max_length / max_position_embeddings
    - (scaling_factor - 1)) ** (rotary_dim / (rotary_dim - 2))`` where
    ``max_length = max_position_embeddings * scaling_factor``, and the table
    covers ``max_length`` positions.

    Args:
        base: RoPE base before adjustment.
        rotary_dim: Rotary dimension.
        max_position_embeddings: Unscaled context length.
        scaling_factor: Context extension factor.

    Returns:
        Array with cosines then sines along the last axis.
    """
    max_length = max_position_embeddings * scaling_factor
    growth = (scaling_factor * max_length / max_position_embeddings) - (
        scaling_factor - 1
    )
    base = base * growth ** (rotary_dim / (rotary_dim - 2))

    inv_freq = compute_basic_inv_frequencies(base=base, rotary_dim=rotary_dim)
    positions = jnp.arange(max_length, dtype=jnp.float32)
    angles = jnp.einsum("i,j -> ij", positions, inv_freq)
    return jnp.concatenate([jnp.cos(angles), jnp.sin(angles)], -1)
@jax.named_scope("easydel-rotary-compute-yarn-frequencies")
def compute_yarn_frequencies(
    base: float,
    rotary_dim: int,
    beta_fast: float,
    beta_slow: float,
    max_position_embeddings: int,
    scaling_factor: float,
    extrapolation_factor: float,
    attn_factor: float,
) -> jnp.ndarray:
    """Build the YaRN cos/sin table.

    The table covers ``max_position_embeddings * scaling_factor`` positions
    and both halves are multiplied by
    ``_yarn_get_mscale(scaling_factor) * attn_factor``.

    Args:
        base: RoPE base.
        rotary_dim: Rotary dimension.
        beta_fast: Forwarded to ``compute_yarn_inv_frequencies``.
        beta_slow: Forwarded to ``compute_yarn_inv_frequencies``.
        max_position_embeddings: Context length before scaling.
        scaling_factor: Context extension factor.
        extrapolation_factor: Forwarded to ``compute_yarn_inv_frequencies``.
        attn_factor: Extra multiplier on the magnitude scale.

    Returns:
        Array with scaled cosines then scaled sines along the last axis.
    """
    inv_freq = compute_yarn_inv_frequencies(
        base=base,
        rotary_dim=rotary_dim,
        beta_fast=beta_fast,
        beta_slow=beta_slow,
        max_position_embeddings=max_position_embeddings,
        scaling_factor=scaling_factor,
        extrapolation_factor=extrapolation_factor,
    )
    positions = jnp.arange(
        max_position_embeddings * scaling_factor, dtype=jnp.float32
    )
    angles = jnp.einsum("i,j -> ij", positions, inv_freq)
    magnitude = _yarn_get_mscale(scaling_factor) * attn_factor
    return jnp.concatenate(
        [jnp.cos(angles) * magnitude, jnp.sin(angles) * magnitude],
        axis=-1,
    )
@jax.named_scope("easydel-rotary-compute-phi3-frequencies")
def compute_phi3_frequencies(
    base,
    head_size,
    rotary_dim,
    max_position_embeddings,
    original_max_position_embeddings,
    short_factor,
    long_factor,
):
    """Build the Phi-3 LongRoPE cos/sin table.

    Args:
        base: RoPE base.
        head_size: Attention head size; must equal ``rotary_dim``.
        rotary_dim: Rotary dimension.
        max_position_embeddings: Number of positions in the table.
        original_max_position_embeddings: Pre-extension context length.
        short_factor: Per-frequency factors used when
            ``max_position_embeddings <= original_max_position_embeddings``.
        long_factor: Per-frequency factors used otherwise.

    Returns:
        Array of shape ``(1, max_position_embeddings, 2 * head_size)``: the
        cosine block (angles duplicated to ``head_size``) followed by the
        sine block, both multiplied by the length-dependent scaling factor.

    Raises:
        ValueError: If ``rotary_dim`` differs from ``head_size``.
    """
    if rotary_dim != head_size:
        raise ValueError(f"rotary_dim != head_size ({rotary_dim}!={head_size})")

    extended = max_position_embeddings > original_max_position_embeddings
    ext_factors = jnp.array(
        long_factor if extended else short_factor, dtype=jnp.float32
    )

    exponents = (
        jnp.arange(0, head_size, 2, dtype=jnp.int32).astype(jnp.float32) / head_size
    )
    inv_freq = 1.0 / (ext_factors * (base**exponents))

    # (1, D/2, 1) @ (1, 1, P) -> (1, D/2, P), then swap to (1, P, D/2).
    inv_freq_column = jnp.expand_dims(inv_freq, (0, 2)).astype(jnp.float32)
    position_ids = jnp.arange(max_position_embeddings, dtype=jnp.int32).reshape(1, -1)
    position_row = jnp.expand_dims(position_ids, 1).astype(jnp.float32)
    angles = (inv_freq_column @ position_row).swapaxes(1, 2)
    emb = jnp.concatenate((angles, angles), axis=-1)

    scale = max_position_embeddings / original_max_position_embeddings
    scaling_factor = (
        1.0
        if scale <= 1.0
        else math.sqrt(
            1 + math.log(scale) / math.log(original_max_position_embeddings)
        )
    )

    return jnp.concatenate(
        [jnp.cos(emb) * scaling_factor, jnp.sin(emb) * scaling_factor],
        axis=-1,
    )
@jax.named_scope("easydel-rotary-compute-llama3-frequencies")
def compute_llama3_frequencies(
    base,
    rotary_dim,
    low_freq_factor,
    high_freq_factor,
    scaling_factor,
    max_position_embeddings: int,
):
    """Build the Llama-3 cos/sin table.

    Args:
        base: RoPE base.
        rotary_dim: Rotary dimension.
        low_freq_factor: Forwarded to ``compute_llama3_inv_frequencies``.
        high_freq_factor: Forwarded to ``compute_llama3_inv_frequencies``.
        scaling_factor: Forwarded to ``compute_llama3_inv_frequencies``.
        max_position_embeddings: Used twice: as the ``orig_max_position``
            argument of the inverse-frequency computation and as the number
            of rows in the returned table.

    Returns:
        Array of shape ``(max_position_embeddings, rotary_dim)`` with
        cosines then sines along the last axis.
    """
    inv_freq = compute_llama3_inv_frequencies(
        base,
        rotary_dim,
        low_freq_factor,
        high_freq_factor,
        max_position_embeddings,
        scaling_factor,
    )
    positions = jnp.arange(max_position_embeddings, dtype=jnp.float32)
    angles = jnp.einsum("i,j -> ij", positions, inv_freq)
    return jnp.concatenate([jnp.cos(angles), jnp.sin(angles)], axis=-1)
@jax.named_scope("easydel-rotary-compute-deepseek-frequencies")
def compute_deepseek_frequencies(
    base,
    rotary_dim,
    scaling_factor,
    extrapolation_factor,
    beta_fast,
    beta_slow,
    max_position_embeddings,
    mscale,
    mscale_all_dim,
    attn_factor,
) -> jnp.ndarray:
    """Build the DeepSeek variant of the YaRN cos/sin table.

    The inverse frequencies are blended exactly as in YaRN; the magnitude
    scale is ``yarn_get_mscale(scaling_factor, mscale)
    / yarn_get_mscale(scaling_factor, mscale_all_dim) * attn_factor``.

    Args:
        base: RoPE base.
        rotary_dim: Rotary dimension.
        scaling_factor: Context extension factor.
        extrapolation_factor: Multiplier on the blend mask.
        beta_fast: First rotation bound for the correction range.
        beta_slow: Second rotation bound for the correction range.
        max_position_embeddings: Context length before scaling.
        mscale: Numerator magnitude coefficient.
        mscale_all_dim: Denominator magnitude coefficient.
        attn_factor: Extra multiplier on the magnitude scale.

    Returns:
        Array covering ``max_position_embeddings * scaling_factor``
        positions, scaled cosines then scaled sines along the last axis.
    """
    exponents = jnp.arange(0, rotary_dim, 2, dtype=jnp.float32) / rotary_dim
    pos_freqs = base**exponents
    extrapolated = 1.0 / pos_freqs
    interpolated = 1.0 / (scaling_factor * pos_freqs)

    low, high = _yarn_find_correction_range(
        beta_fast,
        beta_slow,
        rotary_dim,
        base,
        max_position_embeddings,
    )
    ramp = _yarn_linear_ramp_mask(low, high, rotary_dim // 2, dtype=jnp.float32)
    extrapolation_weight = (1 - ramp) * extrapolation_factor
    inv_freq = (
        interpolated * (1 - extrapolation_weight)
        + extrapolated * extrapolation_weight
    )

    positions = jnp.arange(
        max_position_embeddings * scaling_factor,
        dtype=jnp.float32,
    )
    angles = jnp.einsum("i,j -> ij", positions, inv_freq)
    magnitude = (
        yarn_get_mscale(scaling_factor, mscale)
        / yarn_get_mscale(scaling_factor, mscale_all_dim)
        * attn_factor
    )
    return jnp.concatenate(
        [jnp.cos(angles) * magnitude, jnp.sin(angles) * magnitude], axis=-1
    )
@jax.named_scope("easydel-rotary-apply-basic-rope")
def apply_basic_rope(
    query: jax.Array,
    key: jax.Array,
    positions: jax.Array,
    frequencies: jax.Array,
    rotary_dim: int,
    is_neox_style: bool,
    offsets: jax.Array = None,
    dtype: jnp.dtype = jnp.float32,
):
    """Rotate ``query`` and ``key`` using a precomputed cos/sin table.

    Args:
        query: Query tensor.
        key: Key tensor.
        positions: Integer indices into ``frequencies``.
        frequencies: Table whose last axis holds cosines then sines.
        rotary_dim: Number of leading features (last axis) to rotate.
        is_neox_style: Layout flag forwarded to ``_apply_rotary_emb``.
        offsets: Optional value added to ``positions`` before the lookup.
        dtype: dtype both outputs are cast to.

    Returns:
        ``(query, key)`` after rotation and cast. When ``rotary_dim``
        differs from ``query``'s last-axis size, only the first
        ``rotary_dim`` features of each tensor are rotated and the rest are
        passed through.
    """
    if offsets is not None:
        positions = positions + offsets
    cos, sin = jnp.split(frequencies[positions], 2, -1)

    # The decision is taken from the query's width and applied to both.
    rotate_prefix_only = rotary_dim != query.shape[-1]

    rotated = []
    for tensor in (query, key):
        if rotate_prefix_only:
            prefix = _apply_rotary_emb(
                tensor[..., :rotary_dim], cos, sin, is_neox_style
            )
            tensor = jnp.concatenate((prefix, tensor[..., rotary_dim:]), axis=-1)
        else:
            tensor = _apply_rotary_emb(tensor, cos, sin, is_neox_style)
        rotated.append(tensor.astype(dtype))
    return rotated[0], rotated[1]
@jax.named_scope("easydel-rotary-apply-phi3-rope")
def apply_phi3_rope(
    query,
    key,
    positions,
    frequencies,
    offsets: jax.Array = None,
    dtype: jnp.dtype = jnp.float32,
):
    """Rotate ``query`` and ``key`` with a Phi-3 LongRoPE table.

    Args:
        query: Query tensor.
        key: Key tensor.
        positions: Integer indices into axis 1 of ``frequencies``.
        frequencies: Table with a leading axis of which index 0 is used;
            the last axis holds cosines then sines.
        offsets: Optional value added to ``positions`` before the lookup.
        dtype: dtype both outputs are cast to.

    Returns:
        ``(query, key)`` rotated NeoX-style over their full last axis and
        cast to ``dtype``.
    """
    if offsets is not None:
        positions = positions + offsets

    cos, sin = jnp.split(frequencies[0, positions], 2, axis=-1)
    # Insert an axis at position 2 so the tables broadcast over that axis
    # of query/key.
    cos = jnp.expand_dims(cos, 2)
    sin = jnp.expand_dims(sin, 2)

    with jax.default_matmul_precision("float32"):
        rotated_query = query * cos + _rotate_neox(query) * sin
        rotated_key = key * cos + _rotate_neox(key) * sin
    return rotated_query.astype(dtype), rotated_key.astype(dtype)
@rope_wraper("default")
class RotaryEmbedding(nn.Module):
    """Rotary position embedding without context-length scaling.

    The module holds configuration only. The cos/sin table is either
    supplied to ``__call__`` or rebuilt with ``compute_basic_frequencies``.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: jnp.dtype,
    ):
        """Store the rotary configuration.

        Args:
            head_size: Attention head size.
            rotary_dim: Number of features that get rotated.
            max_position_embeddings: Rows of the fallback cos/sin table.
            base: RoPE base.
            is_neox_style: NeoX (half-split) vs GPT-J (interleaved) layout.
            dtype: dtype of the rotated outputs.
        """
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype

    @jax.named_scope("easydel-rope-embedding")
    def __call__(
        self,
        positions: jnp.ndarray,
        query: jnp.ndarray,
        key: jnp.ndarray,
        offsets: tp.Optional[jnp.ndarray] = None,
        frequencies: tp.Optional[jnp.ndarray] = None,
    ) -> tp.Tuple[jnp.ndarray, jnp.ndarray]:
        """Rotate ``query`` and ``key`` at the given positions.

        Args:
            positions: Integer indices into the cos/sin table.
            query: Query tensor.
            key: Key tensor.
            offsets: Optional value added to ``positions``.
            frequencies: Precomputed table; built on the fly when ``None``.
                An object exposing ``.value`` is unwrapped first.

        Returns:
            The rotated ``(query, key)`` pair cast to ``self.dtype``.
        """
        with jax.ensure_compile_time_eval():
            if frequencies is None:
                frequencies = compute_basic_frequencies(
                    base=self.base,
                    rotary_dim=self.rotary_dim,
                    max_position_embeddings=self.max_position_embeddings,
                )
            # Unwrap variable-like containers that carry the array in .value.
            frequencies = getattr(frequencies, "value", frequencies)
            return apply_basic_rope(
                query=query,
                key=key,
                positions=positions,
                offsets=offsets,
                frequencies=frequencies,
                rotary_dim=self.rotary_dim,
                is_neox_style=self.is_neox_style,
                dtype=self.dtype,
            )
@rope_wraper("linear")
class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """Rotary embedding whose fallback table uses linear position scaling."""

    def __init__(
        self,
        scaling_factors: tp.Union[tp.List[float], float],
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: jnp.dtype,
    ):
        """Store the base configuration plus the scaling factor(s).

        Args:
            scaling_factors: One factor or a list of factors, forwarded to
                ``compute_linear_frequencies``.
            head_size: Attention head size.
            rotary_dim: Number of features that get rotated.
            max_position_embeddings: Unscaled context length.
            base: RoPE base.
            is_neox_style: NeoX vs GPT-J layout.
            dtype: dtype of the rotated outputs.
        """
        super().__init__(
            head_size=head_size,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position_embeddings,
            base=base,
            is_neox_style=is_neox_style,
            dtype=dtype,
        )
        self.scaling_factors = scaling_factors

    @jax.named_scope("easydel-rope-linear-scaling")
    def __call__(
        self,
        positions: jnp.ndarray,
        query: jnp.ndarray,
        key: jnp.ndarray,
        offsets: tp.Optional[jnp.ndarray] = None,
        frequencies: tp.Optional[jnp.ndarray] = None,
    ) -> tp.Tuple[jnp.ndarray, jnp.ndarray]:
        """Rotate ``query`` and ``key`` at the given positions.

        Args:
            positions: Integer indices into the cos/sin table.
            query: Query tensor.
            key: Key tensor.
            offsets: Optional value added to ``positions``.
            frequencies: Precomputed table; built with
                ``compute_linear_frequencies`` when ``None``. An object
                exposing ``.value`` is unwrapped first.

        Returns:
            The rotated ``(query, key)`` pair cast to ``self.dtype``.
        """
        with jax.ensure_compile_time_eval():
            if frequencies is None:
                frequencies = compute_linear_frequencies(
                    base=self.base,
                    rotary_dim=self.rotary_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factors=self.scaling_factors,
                )
            frequencies = getattr(frequencies, "value", frequencies)
            return apply_basic_rope(
                query=query,
                key=key,
                positions=positions,
                offsets=offsets,
                frequencies=frequencies,
                rotary_dim=self.rotary_dim,
                is_neox_style=self.is_neox_style,
                dtype=self.dtype,
            )
@rope_wraper("dynamic")
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """Rotary embedding whose fallback table uses dynamic NTK scaling."""

    def __init__(
        self,
        scaling_factor: tp.Union[tp.List[float], float],
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: jnp.dtype,
    ):
        """Store the base configuration plus the NTK scaling factor.

        Args:
            scaling_factor: Context extension factor forwarded to
                ``compute_dynamic_frequencies``.
            head_size: Attention head size.
            rotary_dim: Number of features that get rotated.
            max_position_embeddings: Unscaled context length.
            base: RoPE base.
            is_neox_style: NeoX vs GPT-J layout.
            dtype: dtype of the rotated outputs.
        """
        super().__init__(
            head_size=head_size,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position_embeddings,
            base=base,
            is_neox_style=is_neox_style,
            dtype=dtype,
        )
        self.scaling_factor = scaling_factor

    @jax.named_scope("easydel-rope-dynamic-ntk-scaling")
    def __call__(
        self,
        positions: jnp.ndarray,
        query: jnp.ndarray,
        key: jnp.ndarray,
        offsets: tp.Optional[jnp.ndarray] = None,
        frequencies: tp.Optional[jnp.ndarray] = None,
    ) -> tp.Tuple[jnp.ndarray, jnp.ndarray]:
        """Rotate ``query`` and ``key`` at the given positions.

        Args:
            positions: Integer indices into the cos/sin table.
            query: Query tensor.
            key: Key tensor.
            offsets: Optional value added to ``positions``.
            frequencies: Precomputed table; built with
                ``compute_dynamic_frequencies`` when ``None``. An object
                exposing ``.value`` is unwrapped first.

        Returns:
            The rotated ``(query, key)`` pair cast to ``self.dtype``.
        """
        with jax.ensure_compile_time_eval():
            if frequencies is None:
                frequencies = compute_dynamic_frequencies(
                    base=self.base,
                    rotary_dim=self.rotary_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=self.scaling_factor,
                )
            frequencies = getattr(frequencies, "value", frequencies)
            return apply_basic_rope(
                query=query,
                key=key,
                positions=positions,
                offsets=offsets,
                frequencies=frequencies,
                rotary_dim=self.rotary_dim,
                is_neox_style=self.is_neox_style,
                dtype=self.dtype,
            )
@rope_wraper("yarn")
class YaRNScalingRotaryEmbedding(RotaryEmbedding):
    """Rotary embedding whose fallback table uses the YaRN method.

    Credits to Peng et al. github.com/jquesnelle/yarn
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: jnp.dtype,
        scaling_factor: tp.Union[float, int] = 1.0,
        extrapolation_factor: float = 1.0,
        attn_factor: float = 1.0,
        beta_fast: int = 32,
        beta_slow: int = 1,
    ):
        """Store the base configuration plus the YaRN parameters.

        Args:
            head_size: Attention head size.
            rotary_dim: Number of features that get rotated.
            max_position_embeddings: Context length before scaling.
            base: RoPE base.
            is_neox_style: NeoX vs GPT-J layout.
            dtype: dtype of the rotated outputs.
            scaling_factor: Context extension factor.
            extrapolation_factor: Multiplier on the blend mask.
            attn_factor: Extra multiplier on the magnitude scale.
            beta_fast: First rotation bound for the correction range.
            beta_slow: Second rotation bound for the correction range.
        """
        super().__init__(
            head_size=head_size,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position_embeddings,
            base=base,
            is_neox_style=is_neox_style,
            dtype=dtype,
        )
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow

    @jax.named_scope("easydel-rope-yarn-scaling")
    def __call__(
        self,
        positions: jnp.ndarray,
        query: jnp.ndarray,
        key: jnp.ndarray,
        offsets: tp.Optional[jnp.ndarray] = None,
        frequencies: tp.Optional[jnp.ndarray] = None,
    ) -> tp.Tuple[jnp.ndarray, jnp.ndarray]:
        """Rotate ``query`` and ``key`` at the given positions.

        Args:
            positions: Integer indices into the cos/sin table.
            query: Query tensor.
            key: Key tensor.
            offsets: Optional value added to ``positions``.
            frequencies: Precomputed table; built with
                ``compute_yarn_frequencies`` when ``None``. An object
                exposing ``.value`` is unwrapped first.

        Returns:
            The rotated ``(query, key)`` pair cast to ``self.dtype``.
        """
        with jax.ensure_compile_time_eval():
            if frequencies is None:
                frequencies = compute_yarn_frequencies(
                    base=self.base,
                    rotary_dim=self.rotary_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=self.scaling_factor,
                    beta_fast=self.beta_fast,
                    beta_slow=self.beta_slow,
                    extrapolation_factor=self.extrapolation_factor,
                    attn_factor=self.attn_factor,
                )
            frequencies = getattr(frequencies, "value", frequencies)
            return apply_basic_rope(
                query=query,
                key=key,
                positions=positions,
                offsets=offsets,
                frequencies=frequencies,
                rotary_dim=self.rotary_dim,
                is_neox_style=self.is_neox_style,
                dtype=self.dtype,
            )
@rope_wraper("longrope")
class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
    """Phi-3 LongRoPE rotary embedding with short/long rescale factors."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        original_max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: jnp.dtype,
        short_factor: tp.List[float],
        long_factor: tp.List[float],
    ):
        """Store the LongRoPE configuration.

        Args:
            head_size: Attention head size.
            rotary_dim: Rotary dimension (``compute_phi3_frequencies``
                requires it to equal ``head_size``).
            max_position_embeddings: Rows of the fallback table.
            original_max_position_embeddings: Pre-extension context length.
            base: RoPE base.
            is_neox_style: Stored; ``__call__`` does not read it.
            dtype: dtype of the rotated outputs.
            short_factor: Per-frequency factors for the short regime.
            long_factor: Per-frequency factors for the long regime.
        """
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.original_max_position_embeddings = original_max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype
        self.short_factor = short_factor
        self.long_factor = long_factor

    @jax.named_scope("easydel-rope-phi3-long")
    def __call__(
        self,
        positions: jnp.ndarray,
        query: jnp.ndarray,
        key: jnp.ndarray,
        offsets: tp.Optional[jnp.ndarray] = None,
        frequencies: tp.Optional[jnp.ndarray] = None,
    ) -> tp.Tuple[jnp.ndarray, jnp.ndarray]:
        """Rotate ``query`` and ``key`` at the given positions.

        Args:
            positions: Integer indices into the cos/sin table.
            query: Query tensor.
            key: Key tensor.
            offsets: Optional value added to ``positions``.
            frequencies: Precomputed table; built with
                ``compute_phi3_frequencies`` when ``None``. An object
                exposing ``.value`` is unwrapped first.

        Returns:
            The rotated ``(query, key)`` pair cast to ``self.dtype``.
        """
        with jax.ensure_compile_time_eval():
            if frequencies is None:
                frequencies = compute_phi3_frequencies(
                    base=self.base,
                    head_size=self.head_size,
                    rotary_dim=self.rotary_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    original_max_position_embeddings=self.original_max_position_embeddings,
                    short_factor=self.short_factor,
                    long_factor=self.long_factor,
                )
            frequencies = getattr(frequencies, "value", frequencies)
            return apply_phi3_rope(
                query=query,
                key=key,
                positions=positions,
                offsets=offsets,
                frequencies=frequencies,
                dtype=self.dtype,
            )
@rope_wraper("llama3")
class Llama3RotaryEmbedding(RotaryEmbedding):
    """Rotary embedding whose fallback table uses Llama-3 scaling."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: jnp.dtype,
        scaling_factor: float,
        low_freq_factor: float,
        high_freq_factor: float,
        orig_max_position: int,
    ):
        """Store the base configuration plus the Llama-3 parameters.

        Args:
            head_size: Attention head size.
            rotary_dim: Number of features that get rotated.
            max_position_embeddings: Stored on the instance.
            base: RoPE base.
            is_neox_style: NeoX vs GPT-J layout.
            dtype: dtype of the rotated outputs.
            scaling_factor: Divisor for long-wavelength frequencies.
            low_freq_factor: Long-wavelength threshold factor.
            high_freq_factor: Short-wavelength threshold factor.
            orig_max_position: Pre-extension context length.
        """
        super().__init__(
            head_size=head_size,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position_embeddings,
            base=base,
            is_neox_style=is_neox_style,
            dtype=dtype,
        )
        self.scaling_factor = scaling_factor
        self.low_freq_factor = low_freq_factor
        self.high_freq_factor = high_freq_factor
        self.orig_max_position = orig_max_position

    @jax.named_scope("easydel-rope-llama3")
    def __call__(
        self,
        positions: jnp.ndarray,
        query: jnp.ndarray,
        key: jnp.ndarray,
        offsets: tp.Optional[jnp.ndarray] = None,
        frequencies: tp.Optional[jnp.ndarray] = None,
    ) -> tp.Tuple[jnp.ndarray, jnp.ndarray]:
        """Rotate ``query`` and ``key`` at the given positions.

        Args:
            positions: Integer indices into the cos/sin table.
            query: Query tensor.
            key: Key tensor.
            offsets: Optional value added to ``positions``.
            frequencies: Precomputed table; built with
                ``compute_llama3_frequencies`` when ``None``. An object
                exposing ``.value`` is unwrapped first.

        Returns:
            The rotated ``(query, key)`` pair cast to ``self.dtype``.
        """
        with jax.ensure_compile_time_eval():
            if frequencies is None:
                # NOTE(review): the fallback table is sized by
                # ``orig_max_position``; ``self.max_position_embeddings`` is
                # not used here. Confirm this is intended.
                frequencies = compute_llama3_frequencies(
                    base=self.base,
                    rotary_dim=self.rotary_dim,
                    low_freq_factor=self.low_freq_factor,
                    high_freq_factor=self.high_freq_factor,
                    scaling_factor=self.scaling_factor,
                    max_position_embeddings=self.orig_max_position,
                )
            frequencies = getattr(frequencies, "value", frequencies)
            return apply_basic_rope(
                query=query,
                key=key,
                positions=positions,
                offsets=offsets,
                frequencies=frequencies,
                rotary_dim=self.rotary_dim,
                is_neox_style=self.is_neox_style,
                dtype=self.dtype,
            )
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the magnitude scale used by the DeepSeek YaRN variant.

    Args:
        scale: Context scaling factor.
        mscale: Coefficient applied to the logarithmic term.

    Returns:
        ``1.0`` when ``scale <= 1``, otherwise
        ``0.1 * mscale * log(scale) + 1.0`` (a JAX scalar).
    """
    return 1.0 if scale <= 1 else 0.1 * mscale * jnp.log(scale) + 1.0
@rope_wraper("deepseek_yarn")
class DeepseekScalingRotaryEmbedding(nn.Module):
    """Rotary embedding using the DeepSeek variant of the YaRN method."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: jnp.dtype,
        scaling_factor: float,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
        mscale: float = 1,
        mscale_all_dim: float = 0,
    ):
        """Store the DeepSeek YaRN configuration.

        Args:
            head_size: Attention head size.
            rotary_dim: Number of leading features that get rotated.
            max_position_embeddings: Context length before scaling.
            base: RoPE base.
            is_neox_style: NeoX (half-split) vs GPT-J (interleaved) layout.
            dtype: Stored; ``__call__`` does not cast its outputs.
            scaling_factor: Context extension factor.
            extrapolation_factor: Multiplier on the blend mask.
            attn_factor: Extra multiplier on the magnitude scale.
            beta_fast: First rotation bound for the correction range.
            beta_slow: Second rotation bound for the correction range.
            mscale: Numerator magnitude coefficient.
            mscale_all_dim: Denominator magnitude coefficient.
        """
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = mscale
        self.mscale_all_dim = mscale_all_dim

    @jax.named_scope("easydel-rope-deepseek")
    def __call__(
        self,
        positions: jnp.ndarray,
        query: jnp.ndarray,
        key: jnp.ndarray,
        offsets: tp.Optional[jnp.ndarray] = None,
        frequencies: tp.Optional[jnp.ndarray] = None,
    ) -> tp.Tuple[jnp.ndarray, jnp.ndarray]:
        """Rotate the first ``rotary_dim`` features of ``query`` and ``key``.

        Args:
            positions: Integer indices into the cos/sin table.
            query: Query tensor; features beyond ``rotary_dim`` pass through.
            key: Key tensor; features beyond ``rotary_dim`` pass through.
            offsets: Optional value added to ``positions`` before the table
                lookup.
            frequencies: Precomputed table; built with
                ``compute_deepseek_frequencies`` when ``None``. An object
                exposing ``.value`` is unwrapped first, as in the other
                rotary classes of this module.

        Returns:
            The rotated ``(query, key)`` pair (no dtype cast is applied).
        """
        if frequencies is None:
            frequencies = compute_deepseek_frequencies(
                self.base,
                self.rotary_dim,
                self.scaling_factor,
                self.extrapolation_factor,
                self.beta_fast,
                self.beta_slow,
                self.max_position_embeddings,
                self.mscale,
                self.mscale_all_dim,
                self.attn_factor,
            )
        if hasattr(frequencies, "value"):
            frequencies = frequencies.value

        # Offsets must be applied BEFORE the lookup; adding them afterwards
        # (as was done previously) left them without any effect.
        if offsets is not None:
            positions = positions + offsets
        cos, sin = jnp.split(frequencies[positions], 2, -1)

        query_rot = query[..., : self.rotary_dim]
        key_rot = key[..., : self.rotary_dim]
        if self.rotary_dim < self.head_size:
            query_pass = query[..., self.rotary_dim :]
            key_pass = key[..., self.rotary_dim :]

        # Expand the half-width tables to ``rotary_dim`` along the feature
        # axis, matching the pairing used by the rotation helper:
        #   NeoX  -> [c0, c1, ..., c0, c1, ...]  (tile)
        #   GPT-J -> [c0, c0, c1, c1, ...]       (element-wise repeat)
        # JAX arrays have no ``repeat_interleave`` method, so the GPT-J
        # branch uses ``jnp.repeat``.
        if self.is_neox_style:
            cos = jnp.concatenate((cos, cos), axis=-1)
            sin = jnp.concatenate((sin, sin), axis=-1)
        else:
            cos = jnp.repeat(cos, 2, axis=-1)
            sin = jnp.repeat(sin, 2, axis=-1)
        # Add a singleton axis before the feature axis so the tables
        # broadcast over the axis that precedes it in query/key.
        cos = jnp.expand_dims(cos, axis=-2)
        sin = jnp.expand_dims(sin, axis=-2)

        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
        key_rot = key_rot * cos + rotate_fn(key_rot) * sin

        if self.rotary_dim < self.head_size:
            query = jnp.concatenate((query_rot, query_pass), axis=-1)
            key = jnp.concatenate((key_rot, key_pass), axis=-1)
        else:
            query = query_rot
            key = key_rot
        return query, key
def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: tp.Optional[tp.Dict[str, tp.Any]] = None,
    dtype: tp.Optional[jnp.dtype] = None,
    partial_rotary_factor: float = 1.0,
) -> RotaryEmbedding:
    """Instantiate the rotary-embedding module matching ``rope_scaling``.

    Args:
        head_size: Attention head size.
        rotary_dim: Rotary dimension before ``partial_rotary_factor``.
        max_position: Maximum position count.
        base: RoPE base.
        is_neox_style: NeoX vs GPT-J layout.
        rope_scaling: Scaling configuration; ``None`` selects the plain
            embedding. Otherwise ``rope_scaling["rope_type"]`` picks one of
            ``llama3``, ``default``, ``linear``, ``dynamic``, ``yarn``,
            ``deepseek_yarn`` or ``longrope``.
        dtype: Output dtype; ``jnp.float32`` when ``None``.
        partial_rotary_factor: When below 1.0, ``rotary_dim`` is multiplied
            by it and truncated to int.

    Returns:
        The configured rotary-embedding module.

    Raises:
        KeyError: If a key required by the selected type is missing.
        ValueError: If ``rope_type`` is not recognised.
    """
    if dtype is None:
        dtype = jnp.float32  # Default JAX dtype
    if partial_rotary_factor < 1.0:
        rotary_dim = int(rotary_dim * partial_rotary_factor)

    # Arguments every embedding class takes unchanged.
    shared = dict(
        head_size=head_size,
        rotary_dim=rotary_dim,
        base=base,
        is_neox_style=is_neox_style,
        dtype=dtype,
    )

    def _optional(names):
        # Forward only the optional keys the caller actually provided so the
        # class defaults apply to the rest.
        return {k: v for k, v in rope_scaling.items() if k in names}

    if rope_scaling is None:
        return RotaryEmbedding(max_position_embeddings=max_position, **shared)

    scaling_type = rope_scaling["rope_type"]

    if scaling_type == "llama3":
        return Llama3RotaryEmbedding(
            scaling_factor=rope_scaling["factor"],
            low_freq_factor=rope_scaling["low_freq_factor"],
            high_freq_factor=rope_scaling["high_freq_factor"],
            orig_max_position=rope_scaling["original_max_position_embeddings"],
            max_position_embeddings=max_position,
            **shared,
        )

    if scaling_type == "default":
        return RotaryEmbedding(max_position_embeddings=max_position, **shared)

    if scaling_type == "linear":
        return LinearScalingRotaryEmbedding(
            scaling_factors=rope_scaling["factor"],
            max_position_embeddings=max_position,
            **shared,
        )

    if scaling_type == "dynamic":
        return DynamicNTKScalingRotaryEmbedding(
            scaling_factor=rope_scaling["factor"],
            max_position_embeddings=max_position,
            **shared,
        )

    if scaling_type == "yarn":
        return YaRNScalingRotaryEmbedding(
            scaling_factor=rope_scaling["factor"],
            max_position_embeddings=rope_scaling["original_max_position_embeddings"],
            **shared,
            **_optional(
                ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow")
            ),
        )

    if scaling_type == "deepseek_yarn":
        return DeepseekScalingRotaryEmbedding(
            scaling_factor=rope_scaling["factor"],
            max_position_embeddings=rope_scaling["original_max_position_embeddings"],
            **shared,
            **_optional(
                (
                    "extrapolation_factor",
                    "attn_factor",
                    "beta_fast",
                    "beta_slow",
                    "mscale",
                    "mscale_all_dim",
                )
            ),
        )

    if scaling_type == "longrope":
        return Phi3LongRoPEScaledRotaryEmbedding(
            short_factor=rope_scaling["short_factor"],
            long_factor=rope_scaling["long_factor"],
            original_max_position_embeddings=rope_scaling[
                "original_max_position_embeddings"
            ],
            max_position_embeddings=max_position,
            **shared,
        )

    raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
@partial(
    jax.jit,
    static_argnames=[
        "head_size",
        "rotary_dim",
        "max_position",
        "base",
        "rope_scaling",
        "partial_rotary_factor",
    ],
)
def get_frequencies(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    rope_scaling: tp.Optional[tp.Dict[str, tp.Any]] = None,
    partial_rotary_factor: float = 1.0,
) -> jax.Array:
    """Compute the cos/sin table matching ``rope_scaling``.

    All arguments are static under ``jax.jit`` and must therefore be
    hashable; in particular ``rope_scaling`` has to be a hashable mapping.

    Args:
        head_size: Attention head size (used by ``longrope`` only).
        rotary_dim: Rotary dimension before ``partial_rotary_factor``.
        max_position: Maximum position count.
        base: RoPE base.
        rope_scaling: Scaling configuration; ``None`` selects the plain
            table. Otherwise ``rope_scaling["rope_type"]`` picks one of
            ``llama3``, ``default``, ``linear``, ``dynamic``, ``yarn``,
            ``deepseek_yarn`` or ``longrope``. For ``yarn`` and
            ``deepseek_yarn`` the keys ``extrapolation_factor``,
            ``attn_factor``, ``beta_fast``, ``beta_slow`` (and ``mscale``,
            ``mscale_all_dim`` for DeepSeek) are optional and default to
            the same values as the corresponding embedding classes.
        partial_rotary_factor: When below 1.0, ``rotary_dim`` is multiplied
            by it and truncated to int.

    Returns:
        The cos/sin lookup table.

    Raises:
        KeyError: If a required key of the selected type is missing.
        ValueError: If ``rope_type`` is not recognised.
    """
    if partial_rotary_factor < 1.0:
        rotary_dim = int(rotary_dim * partial_rotary_factor)

    if rope_scaling is None:
        return compute_basic_frequencies(
            base=base,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position,
        )

    scaling_type = rope_scaling["rope_type"]

    if scaling_type == "llama3":
        scaling_factor = rope_scaling["factor"]
        low_freq_factor = rope_scaling["low_freq_factor"]
        high_freq_factor = rope_scaling["high_freq_factor"]
        original_max_position = rope_scaling["original_max_position_embeddings"]
        frequencies = compute_llama3_frequencies(
            base=base,
            rotary_dim=rotary_dim,
            low_freq_factor=low_freq_factor,
            high_freq_factor=high_freq_factor,
            scaling_factor=scaling_factor,
            max_position_embeddings=original_max_position,
        )
    elif scaling_type == "default":
        frequencies = compute_basic_frequencies(
            base=base,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position,
        )
    elif scaling_type == "linear":
        frequencies = compute_linear_frequencies(
            base=base,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position,
            scaling_factors=rope_scaling["factor"],
        )
    elif scaling_type == "dynamic":
        frequencies = compute_dynamic_frequencies(
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position,
            base=base,
            scaling_factor=rope_scaling["factor"],
        )
    elif scaling_type == "yarn":
        scaling_factor = rope_scaling["factor"]
        original_max_position = rope_scaling["original_max_position_embeddings"]
        # Optional keys fall back to the YaRNScalingRotaryEmbedding
        # constructor defaults instead of raising KeyError.
        frequencies = compute_yarn_frequencies(
            base=base,
            rotary_dim=rotary_dim,
            beta_fast=rope_scaling.get("beta_fast", 32),
            beta_slow=rope_scaling.get("beta_slow", 1),
            max_position_embeddings=original_max_position,
            scaling_factor=scaling_factor,
            extrapolation_factor=rope_scaling.get("extrapolation_factor", 1.0),
            attn_factor=rope_scaling.get("attn_factor", 1.0),
        )
    elif scaling_type == "deepseek_yarn":
        scaling_factor = rope_scaling["factor"]
        original_max_position = rope_scaling["original_max_position_embeddings"]
        # Optional keys fall back to the DeepseekScalingRotaryEmbedding
        # constructor defaults instead of raising KeyError.
        frequencies = compute_deepseek_frequencies(
            base,
            rotary_dim,
            scaling_factor,
            rope_scaling.get("extrapolation_factor", 1),
            rope_scaling.get("beta_fast", 32),
            rope_scaling.get("beta_slow", 1),
            original_max_position,
            rope_scaling.get("mscale", 1),
            rope_scaling.get("mscale_all_dim", 0),
            rope_scaling.get("attn_factor", 1),
        )
    elif scaling_type == "longrope":
        short_factor = rope_scaling["short_factor"]
        long_factor = rope_scaling["long_factor"]
        original_max_position = rope_scaling["original_max_position_embeddings"]
        frequencies = compute_phi3_frequencies(
            base=base,
            head_size=head_size,
            rotary_dim=rotary_dim,
            max_position_embeddings=max_position,
            original_max_position_embeddings=original_max_position,
            short_factor=short_factor,
            long_factor=long_factor,
        )
    else:
        raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
    return frequencies
# Example usage
if __name__ == "__main__":

    class _HashableRopeScaling(dict):
        """dict usable as a static ``jax.jit`` argument.

        ``get_frequencies`` marks ``rope_scaling`` as static, and static
        arguments must be hashable; a plain ``dict`` is not.
        """

        def __hash__(self):
            return hash(tuple(sorted(self.items())))

    head_size = 64
    rotary_dim = 64
    max_position = 2048
    base = 10000
    is_neox_style = True
    dtype = jnp.float32
    rope_scaling = _HashableRopeScaling(
        {
            "rope_type": "yarn",
            "factor": 2.0,
            "original_max_position_embeddings": 1024,
            "extrapolation_factor": 1.0,
            "attn_factor": 1.0,
            "beta_fast": 32,
            "beta_slow": 1,
        }
    )
    rope = get_rope(
        head_size,
        rotary_dim,
        max_position,
        base,
        is_neox_style,
        rope_scaling,
        dtype,
    )
    freq = get_frequencies(
        head_size,
        rotary_dim,
        max_position,
        base,
        rope_scaling,
    )
@chex.dataclass
class RopeConfig:
    """Container for the RoPE scaling parameters read from a model config.

    Every field except ``rope_type`` defaults to ``None``; ``to_dict`` omits
    fields that are ``None``.
    """

    # Scaling scheme name.
    rope_type: str = "default"
    # Context scaling factor.
    factor: tp.Optional[float] = None
    # Llama-3 threshold factors.
    low_freq_factor: tp.Optional[float] = None
    high_freq_factor: tp.Optional[float] = None
    # Pre-extension context length.
    original_max_position_embeddings: tp.Optional[int] = None
    # LongRoPE factors. NOTE(review): annotated as float, but
    # ``compute_phi3_frequencies`` turns these into arrays and the Phi-3
    # class annotates them as lists -- confirm the intended type.
    long_factor: tp.Optional[float] = None
    short_factor: tp.Optional[float] = None
    long_mscale: tp.Optional[float] = None
    short_mscale: tp.Optional[float] = None

    @classmethod
    def from_dict(cls, config_dict: tp.Dict[str, tp.Any]) -> "RopeConfig":
        """Build a ``RopeConfig`` from a plain mapping.

        ``rope_type`` is read from the ``"rope_type"`` key; when that is
        missing or falsy, ``"type"`` is used, defaulting to ``"default"``.
        All other fields are copied with ``dict.get`` (missing -> ``None``).
        Keys that are not fields of this class are ignored.
        """
        rope_type = config_dict.get("rope_type") or config_dict.get("type", "default")
        copied_fields = (
            "factor",
            "low_freq_factor",
            "high_freq_factor",
            "original_max_position_embeddings",
            "long_factor",
            "short_factor",
            "long_mscale",
            "short_mscale",
        )
        values = {name: config_dict.get(name) for name in copied_fields}
        return cls(rope_type=rope_type, **values)

    def to_dict(self) -> tp.Dict[str, tp.Any]:
        """Return the non-``None`` fields as a hashable ``dict`` subclass.

        The returned object is a ``dict`` whose ``__hash__`` is
        ``easydel.utils.compiling_utils.hash_fn``.
        """
        from easydel.utils.compiling_utils import hash_fn

        class rope_scaling(dict):
            __hash__ = hash_fn

        populated = {
            name: value for name, value in self.__dict__.items() if value is not None
        }
        return rope_scaling(populated)