# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import jax
import jax.numpy as jnp
from eformer import common_types
from eformer.escale import apply_logical_sharding
from ejkernel.types import MaskInfo
from flax import nnx as nn
from jax.ad_checkpoint import checkpoint_name
from jaxtyping import Array, Bool, Float, Int
from easydel.infra.base_module import EasyDeLBaseModule
from easydel.infra.factory import TaskType, register_module
from easydel.infra.modeling_outputs import BaseModelOutput, DecoderLayerOutput
from easydel.infra.utils import ACT2FN, auto_remat, block_wise_ffn, get_dot_general_by_bits
from easydel.layers.attention_unified import UnifiedAttention
from easydel.layers.base_modules import BaseCausalLMModule, BaseSequenceClassificationModule
from easydel.layers.caching import (
RaggedPagesCache,
RaggedPagesCacheView,
RaggedPagesMetadata,
TransformerCache,
TransformerCacheView,
TransformerMetadata,
)
from easydel.layers.linear import ColumnParallelLinear, RowParallelLinear
from easydel.layers.norms import RMSNorm
from .llama_configuration import LlamaConfig
class LlamaMLP(nn.Module):
    """Gated feed-forward block used inside every Llama decoder layer.

    Computes ``down_proj(act_fn(gate_proj(x)) * up_proj(x))`` followed by
    dropout. ``act_fn`` is looked up in ``ACT2FN`` from ``config.hidden_act``,
    which yields the SwiGLU form when that activation is SiLU.
    """

    def __init__(
        self,
        config: LlamaConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Create the gate/up/down projections, the dropout layer and the activation.

        Args:
            config: Model configuration. Reads ``hidden_size``, ``intermediate_size``,
                ``mlp_bias``, ``initializer_range``, ``bits``, ``easy_method``,
                ``resid_pdrop`` and ``hidden_act``.
            dtype: Computation dtype forwarded to the projections.
            param_dtype: Parameter dtype forwarded to the projections.
            precision: Optional JAX matmul precision forwarded to the projections.
            rngs: Flax NNX random streams handed to every sub-layer.
        """
        self.config = config
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision

        def _linear_factory(linear_cls):
            # One partial per linear class, each with its own initializer and its
            # own dot-general settings (one ``get_dot_general_by_bits`` call each).
            return partial(
                linear_cls,
                dtype=dtype,
                param_dtype=param_dtype,
                use_bias=config.mlp_bias,
                kernel_init=jax.nn.initializers.normal(config.initializer_range),
                precision=precision,
                rngs=rngs,
                **get_dot_general_by_bits(config.bits, config.easy_method),
            )

        make_column = _linear_factory(ColumnParallelLinear)
        make_row = _linear_factory(RowParallelLinear)

        # Keep this construction order: every constructor receives the shared
        # ``rngs`` and may draw from it, so reordering could change initialization.
        self.gate_proj = make_column(config.hidden_size, config.intermediate_size)
        self.down_proj = make_row(config.intermediate_size, config.hidden_size)
        self.up_proj = make_column(config.hidden_size, config.intermediate_size)
        self.dropout = nn.Dropout(rate=config.resid_pdrop, rngs=rngs)
        self.act_fn = ACT2FN[config.hidden_act]

    def __call__(
        self, hidden_states: Float[Array, "batch seq_len hidden_dim"]
    ) -> Float[Array, "batch seq_len hidden_dim"]:
        """Apply the gated feed-forward transformation.

        Args:
            hidden_states: Input activations ``[batch, seq_len, hidden_dim]``.

        Returns:
            Activations of the same shape, with the hidden-state sharding
            constraint applied on both the input and the output.
        """
        shard = partial(
            apply_logical_sharding,
            dynamic_axes=common_types.HiddenStateSharding,
            partition_manager=self.config.partition_manager,
        )
        x = shard(hidden_states)
        gated = checkpoint_name(self.act_fn(self.gate_proj(x)), "mlp_gate")
        lifted = checkpoint_name(self.up_proj(x), "mlp_up")
        projected = checkpoint_name(self.down_proj(gated * lifted), "mlp_down")
        out = shard(self.dropout(projected))
        return checkpoint_name(out, "mlp_output")
class LlamaAttention(UnifiedAttention):
    """Causal self-attention for Llama models, built on ``UnifiedAttention``.

    The attention computation lives in the parent class. This subclass only
    pins ``attention_type="standard"`` and ``causal=True``.
    """

    def __init__(
        self,
        config: LlamaConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
        layer_idx: int,
    ):
        """Forward all arguments to ``UnifiedAttention`` with causal, standard attention.

        Args:
            config: Llama model configuration.
            dtype: Computation dtype.
            param_dtype: Parameter dtype.
            precision: Optional JAX matmul precision.
            rngs: Flax NNX random streams.
            layer_idx: Index of the decoder layer that owns this attention module.
        """
        super().__init__(
            attention_type="standard",
            causal=True,
            layer_idx=layer_idx,
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
class LlamaDecoderLayer(nn.Module):
    """One pre-norm Llama transformer block.

    Data flow: ``h = x + attn(input_layernorm(x))`` then
    ``out = h + mlp(post_attention_layernorm(h))``. Both sub-block classes are
    wrapped by ``auto_remat`` using the config's gradient-checkpointing settings.
    """

    def __init__(
        self,
        config: LlamaConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
        layer_idx: int,
    ):
        """Build the attention, feed-forward and the two RMSNorm sub-modules.

        Args:
            config: Llama model configuration.
            dtype: Computation dtype.
            param_dtype: Parameter dtype.
            precision: Optional JAX matmul precision.
            rngs: Flax NNX random streams shared by all sub-modules.
            layer_idx: Index of this layer in the decoder stack.
        """
        self.config = config
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision

        # Wrap both sub-block classes for rematerialisation in a single call.
        remat_attention, remat_mlp = auto_remat(
            LlamaAttention,
            LlamaMLP,
            policy=config.gradient_checkpointing,
            save_names=config.gradient_checkpointing_targets,
            exclude_names=config.gradient_checkpointing_targets,
        )
        block_kwargs = dict(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        # Construction order (attention, mlp, norms) is preserved because all of
        # them share ``rngs``.
        self.self_attn = remat_attention(**block_kwargs, layer_idx=layer_idx)
        self.mlp = remat_mlp(**block_kwargs)

        norm_kwargs = dict(
            dim=config.hidden_size,
            eps=config.rms_norm_eps,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.input_layernorm = RMSNorm(**norm_kwargs)
        self.post_attention_layernorm = RMSNorm(**norm_kwargs)

    def __call__(
        self,
        hidden_states: Float[Array, "batch seq_len hidden_dim"],
        mask_info: MaskInfo | None,
        position_ids: Int[Array, "batch seq_len"],
        mode: common_types.RUNTIME_MODE_TYPES,  # type:ignore
        cache_view: TransformerCacheView | RaggedPagesCacheView | None = None,
        cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None,
        output_attentions: bool = False,
        frequencies: Float[Array, "seq_len head_dim"] | None = None,
    ) -> DecoderLayerOutput:
        """Run attention and feed-forward sub-blocks, each with a residual connection.

        Args:
            hidden_states: Layer input ``[batch, seq_len, hidden_dim]``.
            mask_info: Attention mask information forwarded to the attention module.
            position_ids: Token positions forwarded to the attention module.
            mode: Runtime mode forwarded to the attention module.
            cache_view: This layer's KV-cache view, if any.
            cache_metadata: Cache metadata forwarded to the attention module.
            output_attentions: Whether the attention module should return its weights.
            frequencies: Precomputed frequencies forwarded to the attention module.

        Returns:
            ``DecoderLayerOutput`` containing the new hidden states together with
            the attention weights and updated cache view reported by attention.
        """
        residual = hidden_states
        # Positional arguments on purpose: ``self_attn`` may be remat-wrapped.
        attention = self.self_attn(
            self.input_layernorm(residual),
            mask_info,
            position_ids,
            mode,
            cache_view,
            cache_metadata,
            output_attentions,
            frequencies,
        )
        residual = checkpoint_name(residual + attention.attention_output, "residual")

        mlp_input = self.post_attention_layernorm(residual)
        if self.config.use_scan_mlp:
            # Process the sequence in chunks of ``scan_mlp_chunk_size``.
            mlp_output = block_wise_ffn(self.mlp, mlp_input, self.config.scan_mlp_chunk_size)
        else:
            mlp_output = self.mlp(mlp_input)
        residual = checkpoint_name(residual + mlp_output, "residual")

        sharded = apply_logical_sharding(
            residual,
            dynamic_axes=common_types.HiddenStateSharding,
            partition_manager=self.config.partition_manager,
        )
        return DecoderLayerOutput(
            hidden_states=checkpoint_name(sharded, "layer_output"),
            attention_weight=attention.attention_weight,
            cache_view=attention.cache_view,
        )
@register_module(TaskType.BASE_MODULE, config=LlamaConfig, model_type="llama")
class LlamaModel(EasyDeLBaseModule):
    """Decoder-only Llama backbone without a language-model head.

    Pipeline: token embedding -> dropout -> ``num_hidden_layers`` x
    ``LlamaDecoderLayer`` -> final ``RMSNorm``.

    Attributes:
        config (LlamaConfig): Configuration for the model.
        dtype (jnp.dtype): Data type for computations.
        param_dtype (jnp.dtype): Data type for parameters.
        precision: Precision setting for JAX operations.
    """

    def __init__(
        self,
        config: LlamaConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Build the embedding table, the decoder stack and the final norm.

        Args:
            config: Llama model configuration.
            dtype: Computation dtype.
            param_dtype: Parameter dtype.
            precision: Optional JAX matmul precision.
            rngs: Flax NNX random streams shared by all sub-modules.
        """
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        # The embedding is also subject to the gradient-checkpointing policy.
        embed_block = auto_remat(
            nn.Embed,
            policy=config.gradient_checkpointing,
            save_names=config.gradient_checkpointing_targets,
            exclude_names=config.gradient_checkpointing_targets,
        )
        self.embed_tokens = embed_block(
            num_embeddings=self.config.vocab_size,
            features=self.config.hidden_size,
            dtype=dtype,
            param_dtype=param_dtype,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            rngs=rngs,
        )
        self.dropout = nn.Dropout(rate=self.config.embd_pdrop, rngs=rngs)
        self.layers = [
            LlamaDecoderLayer(
                config=config,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
                layer_idx=layer_idx,
            )
            for layer_idx in range(self.config.num_hidden_layers)
        ]
        self.norm = RMSNorm(
            self.config.hidden_size,
            eps=self.config.rms_norm_eps,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )

    def __call__(
        self,
        input_ids: Int[Array, "batch seq_len"] | None = None,
        inputs_embeds: Float[Array, "batch seq_len hidden_dim"] | None = None,
        attention_mask: Bool[Array, "batch seq_len"] | None = None,
        mask_info: MaskInfo | None = None,
        position_ids: Int[Array, "batch seq_len"] | None = None,
        mode: common_types.RUNTIME_MODE_TYPES | None = None,  # type:ignore
        past_key_values: TransformerCache | RaggedPagesCache | None = None,
        cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
    ) -> BaseModelOutput:
        """Forward pass through the Llama backbone.

        Args:
            input_ids: Token ids ``[batch, seq_len]``. Mutually exclusive with
                ``inputs_embeds``; exactly one of the two must be given.
            inputs_embeds: Precomputed embeddings ``[batch, seq_len, hidden_dim]``.
            attention_mask: Padding mask ``[batch, seq_len]``. Combined with
                ``mask_info`` via ``MaskInfo.dynamic_init``.
            mask_info: Optional precomputed mask information. Passed to
                ``MaskInfo.dynamic_init`` together with the inputs and ``attention_mask``.
            position_ids: Token positions. Defaults to ``mask_info.q_position_ids``.
            mode: Runtime mode. When ``None``, ``MODE_DECODE`` is used if
                ``seq_len == 1`` and a cache was passed; otherwise ``MODE_TRAIN``.
            past_key_values: KV cache. When ``None``, an empty ``TransformerCache``
                with one view per layer is created.
            cache_metadata: Metadata forwarded to every decoder layer.
            output_attentions: Whether to collect per-layer attention weights.
            output_hidden_states: Whether to collect the input of every decoder
                layer plus the final normalized output.

        Returns:
            BaseModelOutput: ``last_hidden_state`` (after the final norm),
            ``hidden_states`` (tuple or ``None``), ``attentions`` (tuple or
            ``None``) and the updated ``past_key_values``.

        Raises:
            ValueError: If both or neither of ``input_ids`` / ``inputs_embeds``
                are provided, or if the sequence is longer than
                ``config.max_position_embeddings``.
        """
        # Exactly one of input_ids / inputs_embeds must be supplied.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )
        if inputs_embeds is None:
            inputs_embeds = checkpoint_name(self.embed_tokens(input_ids.astype("i4")), "embeddings")
        sequence_length = inputs_embeds.shape[1]
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        # Explicit check rather than ``assert``: assertions are stripped under ``python -O``.
        if sequence_length > self.config.max_position_embeddings:
            raise ValueError(
                f"Maximum Position Embedding Reached ! "
                f"(Expected <= {self.config.max_position_embeddings} got {sequence_length})"
            )
        mask_info = MaskInfo.dynamic_init(
            mask_info=mask_info,
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )
        if position_ids is None:
            position_ids = mask_info.q_position_ids
        hidden_states = self.dropout(inputs_embeds)
        if mode is None:
            # A single new token with an existing cache means incremental decoding.
            mode = (
                common_types.MODE_DECODE
                if sequence_length == 1 and past_key_values is not None
                else common_types.MODE_TRAIN
            )
        if past_key_values is None:
            past_key_values = TransformerCache.init_empty(len(self.layers))
        hidden_states = apply_logical_sharding(
            hidden_states,
            dynamic_axes=common_types.HiddenStateSharding,
            partition_manager=self.config.partition_manager,
        )
        for idx, block in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            layer_outputs = block(
                hidden_states=hidden_states,
                mask_info=mask_info,
                position_ids=position_ids,
                mode=mode,
                cache_view=past_key_values.views[idx],
                cache_metadata=cache_metadata,
                output_attentions=output_attentions,
                frequencies=self.frequencies,
            )
            hidden_states = layer_outputs.hidden_states
            if output_attentions:
                all_attentions += (layer_outputs.attention_weight,)
            # Write the layer's updated cache view back into the shared cache.
            past_key_values[idx] = layer_outputs.cache_view
        hidden_states = self.norm(hidden_states)
        hidden_states = checkpoint_name(hidden_states, "model_output")
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            past_key_values=past_key_values,
        )

    def get_encoder(self):
        """Always raise: decoder-only models have no encoder.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("This is a decoder-only model and does not have an encoder.")

    def get_decoder(self):
        """Return the decoder, which for this backbone is the module itself."""
        return self

    def get_lm_head(self):
        """Always raise: the base model has no language-model head.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("The base model does not have a language model head.")

    def get_embedding(self):
        """Return the token embedding layer (``self.embed_tokens``)."""
        return self.embed_tokens
@register_module(TaskType.CAUSAL_LM, config=LlamaConfig, model_type="llama")
class LlamaForCausalLM(BaseCausalLMModule[LlamaModel, LlamaConfig]):
    """Llama backbone plus a language-modeling head for autoregressive generation.

    All forward logic comes from ``BaseCausalLMModule``. This class only
    selects ``LlamaModel`` as the backbone (stored under the name ``"model"``)
    and disables the head bias.

    Attributes:
        config (LlamaConfig): Configuration for the model.
        dtype (jnp.dtype): Data type for computations (default is jnp.bfloat16).
        param_dtype (jnp.dtype): Data type for parameters (default is jnp.bfloat16).
        precision: Precision setting for JAX operations.
    """

    _task_type = TaskType.CAUSAL_LM
    _model_type = "llama"
    _config_class = LlamaConfig

    def __init__(
        self,
        config: LlamaConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Initialize the causal-LM wrapper around ``LlamaModel``.

        Args:
            config: Llama model configuration.
            dtype: Computation dtype.
            param_dtype: Parameter dtype.
            precision: Optional JAX matmul precision.
            rngs: Flax NNX random streams.
        """
        super().__init__(
            base_model_class=LlamaModel,
            base_model_name="model",
            lm_head_bias=False,
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
@register_module(TaskType.SEQUENCE_CLASSIFICATION, config=LlamaConfig, model_type="llama")
class LlamaForSequenceClassification(BaseSequenceClassificationModule[LlamaModel, LlamaConfig]):
    """Llama backbone plus a linear scoring head for sequence classification.

    All forward logic comes from ``BaseSequenceClassificationModule``. This
    class selects ``LlamaModel`` as the backbone (stored under the name
    ``"model"``), pools with the ``"last"`` strategy and disables the score-head bias.

    Attributes:
        config (LlamaConfig): Configuration for the model.
        dtype (jnp.dtype): Data type for computations.
        param_dtype (jnp.dtype): Data type for parameters.
        precision: Precision setting for JAX operations.
    """

    _task_type = TaskType.SEQUENCE_CLASSIFICATION
    _model_type = "llama"
    _config_class = LlamaConfig

    def __init__(
        self,
        config: LlamaConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Initialize the sequence-classification wrapper around ``LlamaModel``.

        Args:
            config: Llama model configuration.
            dtype: Computation dtype.
            param_dtype: Parameter dtype.
            precision: Optional JAX matmul precision.
            rngs: Flax NNX random streams.
        """
        super().__init__(
            base_model_class=LlamaModel,
            base_model_name="model",
            pooling_strategy="last",
            score_head_bias=False,
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )