# Source code for easydel.modules.phi3.modeling_phi3_flax

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import functools
import typing as tp

import chex
import jax.lax
from chex import Array
from flax import nnx as nn
from jax import numpy as jnp

from easydel.infra.base_module import EasyDeLBaseModule
from easydel.infra.factory import register_module
from easydel.infra.modeling_outputs import (
	FlaxBaseModelOutput,
	FlaxCausalLMOutput,
)
from easydel.infra.utils import (
	ACT2FN,
	auto_remat,
	block_wise_ffn,
	control_mlp_sharding,
	get_dot_general_by_bits,
)
from easydel.layers.attention import FlaxAttentionModule, FlexibleAttentionModule
from easydel.layers.caching import TransformerCache, TransformerCacheView
from easydel.layers.norms import RMSNorm as RMSNorm
from easydel.modules.phi3.phi3_configuration import Phi3Config as Phi3Config


class Phi3MLP(nn.Module):
	"""Gated feed-forward block used inside each Phi-3 decoder layer.

	A single fused projection (``gate_up_proj``) produces both halves of the
	gated MLP. Its output is split in two along the last axis. The first half
	(the gate) goes through the configured activation and multiplies the
	second half. The product is then projected back to ``hidden_size`` by
	``down_proj``.

	Args:
		config: Model configuration. Reads ``hidden_size``, ``intermediate_size``,
			``hidden_act``, ``initializer_range``, ``bits``, ``easy_method`` and
			``partition_axis``.
		dtype: Computation dtype of the linear layers.
		param_dtype: Storage dtype of the linear-layer parameters.
		precision: Matmul precision forwarded to the linear layers.
		rngs: Random number streams used for parameter initialization.
	"""

	def __init__(
		self,
		config: Phi3Config,
		dtype: jnp.dtype = jnp.float32,
		param_dtype: jnp.dtype = jnp.float32,
		precision: jax.lax.PrecisionLike = None,
		*,
		rngs: nn.Rngs,
	):
		self.config = config
		self.dtype = dtype
		self.param_dtype = param_dtype
		self.precision = precision

		# Shared constructor for both bias-free projections; quantization-aware
		# dot_general settings come from ``get_dot_general_by_bits``.
		make_dense = functools.partial(
			nn.Linear,
			use_bias=False,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			kernel_init=jax.nn.initializers.normal(config.initializer_range),
			**get_dot_general_by_bits(config.bits, config.easy_method),
		)
		# Creation order (gate_up_proj, then down_proj) is significant: it fixes
		# the sequence in which parameter initializers draw from ``rngs``.
		self.gate_up_proj = make_dense(
			config.hidden_size,
			config.intermediate_size * 2,
			rngs=rngs,
		)
		self.down_proj = make_dense(
			config.intermediate_size,
			config.hidden_size,
			rngs=rngs,
		)
		self.activation_fn = ACT2FN[config.hidden_act]

	def __call__(self, hidden_states: Array) -> Array:
		"""Apply the gated MLP.

		Args:
			hidden_states: Input activations whose last axis has size ``hidden_size``.

		Returns:
			Activations with the same leading shape and a last axis of ``hidden_size``.
		"""
		sharded = control_mlp_sharding(hidden_states, self.config.partition_axis)
		fused = self.gate_up_proj(sharded)
		gate, up = jnp.split(fused, 2, axis=-1)
		return self.down_proj(up * self.activation_fn(gate))
class Phi3Attention(FlaxAttentionModule):
	"""Self-attention layer for Phi-3 with a fused QKV projection.

	``qkv_proj`` produces one tensor laid out as ``[queries | keys | values]``:
	``num_attention_heads * head_dim`` query features, followed by
	``num_key_value_heads * head_dim`` features each for keys and values.
	This allows fewer key/value heads than query heads.

	Processing steps:
	1. Rotary position embeddings (neox style) are applied to queries and keys.
	2. Keys, values and masks are passed through ``self.concatenate`` together
	   with the optional cache view (presumably this reads/updates the KV cache;
	   confirm against ``FlaxAttentionModule``).
	3. Attention is computed by ``FlexibleAttentionModule``.
	4. The result is projected back with ``o_proj``.

	Args:
		config: Model configuration.
		dtype: Computation dtype.
		param_dtype: Storage dtype of the parameters.
		precision: Matmul precision forwarded to the linear layers.
		rngs: Random number streams; also used to draw the attention-dropout key.

	Raises:
		ValueError: If ``hidden_size`` is not divisible by ``num_attention_heads``.
	"""

	def __init__(
		self,
		config: Phi3Config,
		dtype: jnp.dtype = jnp.float32,
		param_dtype: jnp.dtype = jnp.float32,
		precision: jax.lax.PrecisionLike = None,
		*,
		rngs: nn.Rngs,
	):
		super().__init__(config=config)
		self.dtype = dtype
		self.param_dtype = param_dtype
		self.precision = precision
		self.rngs = rngs
		self.attention_dropout = config.attention_dropout
		self.hidden_size = config.hidden_size
		self.num_heads = config.num_attention_heads
		self.head_dim = self.hidden_size // self.num_heads
		self.num_key_value_heads = config.num_key_value_heads
		self.num_key_value_groups = self.num_heads // self.num_key_value_heads
		self.max_position_embeddings = config.max_position_embeddings
		self.original_max_position_embeddings = config.original_max_position_embeddings
		self.is_causal = True

		if (self.head_dim * self.num_heads) != self.hidden_size:
			raise ValueError(
				f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
				f" and `num_heads`: {self.num_heads})."
			)

		linear_class = functools.partial(
			nn.Linear,
			use_bias=False,
			precision=precision,
			dtype=dtype,
			param_dtype=param_dtype,
			kernel_init=jax.nn.initializers.normal(config.initializer_range),
			**get_dot_general_by_bits(config.bits, config.easy_method),
		)
		# Width of the fused projection: all query heads plus key and value heads.
		op_size = self.num_heads * self.head_dim + 2 * (
			self.num_key_value_heads * self.head_dim
		)
		# NOTE: creation order (o_proj before qkv_proj) determines the order in
		# which initializers draw from ``rngs``; keep it stable.
		self.o_proj = linear_class(
			self.num_heads * self.head_dim,
			self.hidden_size,
			rngs=rngs,
		)
		self.qkv_proj = linear_class(
			self.hidden_size,
			op_size,
			rngs=rngs,
		)
		self.attention_performer = FlexibleAttentionModule(
			base_config=config,
			softmax_scale=self.head_dim**-0.5,
			dropout_prob=config.attention_dropout,
		)
		self.rotary = self.config.get_basic_rope(
			self.dtype,
			head_size=self.head_dim,
			base=config.rope_theta,
			is_neox_style=True,
		)

	def __call__(
		self,
		hidden_states: chex.Array,
		attention_mask: chex.Array,
		position_ids: chex.Array,
		causal_mask: chex.Array,
		cache_view: tp.Optional[TransformerCacheView] = None,
		segment_ids: tp.Optional[chex.Array] = None,
		output_attentions: bool = False,
		fcm_mask: tp.Optional[chex.Array] = None,
		frequencies: tp.Optional[chex.Array] = None,
	):
		"""Run the attention forward pass.

		Args:
			hidden_states (chex.Array): Input activations; the first two axes are
				read as (batch, sequence).
			attention_mask (chex.Array): Attention mask, forwarded to ``self.concatenate``.
			position_ids (chex.Array): Token positions used by the rotary embedding.
			causal_mask (chex.Array): Causal mask, forwarded to ``self.concatenate``.
			cache_view (tp.Optional[TransformerCacheView]): This layer's view into the
				key/value cache, forwarded to ``self.concatenate``.
			segment_ids (tp.Optional[chex.Array]): Optional segment IDs forwarded to the
				attention kernel.
			output_attentions (bool): If True, also return the attention weights.
			fcm_mask (tp.Optional[chex.Array]): Optional FCM mask combined with the
				attention and causal masks inside ``self.concatenate``.
			frequencies (tp.Optional[chex.Array]): Optional precomputed rotary
				frequencies passed to ``self.rotary``.

		Returns:
			tuple: ``(attn_output,)``, or ``(attn_output, attention_weights)`` when
			``output_attentions`` is True.
		"""
		batch_size, sequence_length = hidden_states.shape[:2]
		qkv = self.qkv_proj(hidden_states)

		# Slice the fused projection into its [q | k | v] segments.
		query_pos = self.num_heads * self.head_dim
		kv_size = self.num_key_value_heads * self.head_dim
		query_states = qkv[..., :query_pos]
		key_states = qkv[..., query_pos : query_pos + kv_size]
		value_states = qkv[..., query_pos + kv_size :]

		query_states = query_states.reshape(
			batch_size,
			sequence_length,
			self.config.num_attention_heads,
			self.head_dim,
		)
		key_states = key_states.reshape(
			batch_size,
			sequence_length,
			self.config.num_key_value_heads,
			self.head_dim,
		)
		value_states = value_states.reshape(
			batch_size,
			sequence_length,
			self.config.num_key_value_heads,
			self.head_dim,
		)

		query_states, key_states = self.rotary(
			query=query_states,
			key=key_states,
			positions=position_ids,
			frequencies=frequencies,
		)

		(
			key_states,
			value_states,
			attention_mask,
			init_attention_bias,
		) = self.concatenate(
			query=query_states,
			key=key_states,
			cache_view=cache_view,
			value=value_states,
			attention_mask=attention_mask,
			causal_mask=causal_mask,
			fcm_mask=fcm_mask,
		)

		attentions = self.attention_performer.forward(
			query_states=query_states,
			key_states=key_states,
			value_states=value_states,
			bias=None,
			init_bias=init_attention_bias,
			attention_mask=attention_mask,
			segment_ids=segment_ids,
			causal=True,
			dropout_rng=self.rngs.params(),
		)

		attn_output = self.shard_attention_prod(
			self._merge_heads(attentions.attention_outputs)
		)
		attn_output = self.o_proj(attn_output)

		outputs = (
			(attn_output, attentions.attention_weights)
			if output_attentions
			else (attn_output,)
		)
		return outputs
class Phi3DecoderLayer(nn.Module):
	"""One pre-norm Phi-3 transformer block.

	Structure:
	``x = x + dropout(attn(rms_norm(x)))``, followed by
	``x = x + dropout(mlp(rms_norm(x)))``.

	The attention and MLP sub-modules are wrapped with ``auto_remat`` according
	to ``config.gradient_checkpointing``. When ``config.use_scan_mlp`` is set,
	the MLP is evaluated chunk-wise through ``block_wise_ffn``.

	Args:
		config: Model configuration.
		dtype: Computation dtype.
		param_dtype: Storage dtype of the parameters.
		precision: Matmul precision forwarded to sub-modules.
		rngs: Random number streams for initialization and dropout.
	"""

	def __init__(
		self,
		config: Phi3Config,
		dtype: jnp.dtype = jnp.float32,
		param_dtype: jnp.dtype = jnp.float32,
		precision: jax.lax.PrecisionLike = None,
		*,
		rngs: nn.Rngs,
	):
		self.config = config
		self.dtype = dtype
		self.param_dtype = param_dtype
		self.precision = precision

		attn_block = Phi3Attention
		mlp_block = Phi3MLP
		attn_block, mlp_block = auto_remat(
			attn_block,
			mlp_block,
			policy=config.gradient_checkpointing,
		)

		# NOTE: submodule creation order fixes how initializers draw from
		# ``rngs``; keep it stable for reproducible initialization.
		self.self_attn = attn_block(
			config=config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)
		self.mlp = mlp_block(
			config=config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)
		self.input_layernorm = RMSNorm(
			dim=config.hidden_size,
			eps=config.rms_norm_eps,
			dtype=dtype,
			param_dtype=param_dtype,
			rngs=rngs,
		)
		self.resid_attn_dropout = nn.Dropout(
			self.config.resid_pdrop,
			rngs=rngs,
		)
		self.resid_mlp_dropout = nn.Dropout(
			self.config.resid_pdrop,
			rngs=rngs,
		)
		self.post_attention_layernorm = RMSNorm(
			dim=self.config.hidden_size,
			eps=self.config.rms_norm_eps,
			dtype=dtype,
			param_dtype=param_dtype,
			rngs=rngs,
		)

	def __call__(
		self,
		hidden_states: chex.Array,
		attention_mask: chex.Array,
		position_ids: chex.Array,
		causal_mask: chex.Array,
		cache_view: tp.Optional[TransformerCacheView] = None,
		segment_ids: tp.Optional[chex.Array] = None,
		output_attentions: bool = False,
		fcm_mask: tp.Optional[chex.Array] = None,
		frequencies: tp.Optional[chex.Array] = None,
	):
		"""Run the decoder block.

		Args:
			hidden_states (chex.Array): Input activations.
			attention_mask (chex.Array): Attention mask, forwarded to self-attention.
			position_ids (chex.Array): Token positions, forwarded to self-attention.
			causal_mask (chex.Array): Causal mask, forwarded to self-attention.
			cache_view (tp.Optional[TransformerCacheView]): This layer's key/value
				cache view, forwarded to self-attention.
			segment_ids (tp.Optional[chex.Array]): Optional segment IDs.
			output_attentions (bool): If True, also return the attention weights.
			fcm_mask (tp.Optional[chex.Array]): Optional FCM mask, forwarded to
				self-attention.
			frequencies (tp.Optional[chex.Array]): Optional precomputed rotary
				frequencies, forwarded to self-attention.

		Returns:
			tuple: ``(hidden_states,)``, or ``(hidden_states, attention_weights)`` when
			``output_attentions`` is True.
		"""
		residual = hidden_states
		hidden_states = self.input_layernorm(hidden_states)
		# Arguments are passed positionally on purpose: the block may be wrapped
		# by ``auto_remat``.
		attn_out = self.self_attn(
			hidden_states,
			attention_mask,
			position_ids,
			causal_mask,
			cache_view,
			segment_ids,
			output_attentions,
			fcm_mask,
			frequencies,
		)
		attn_outputs = attn_out[0]
		self_attn_weights = attn_out[1] if len(attn_out) == 2 else None

		hidden_states = self.resid_attn_dropout(attn_outputs) + residual

		residual = hidden_states
		hidden_states = self.post_attention_layernorm(hidden_states)
		if self.config.use_scan_mlp:
			feed_forward_hidden_states = block_wise_ffn(
				self.mlp,
				hidden_states,
				self.config.scan_mlp_chunk_size,
			)
		else:
			feed_forward_hidden_states = self.mlp(hidden_states)

		hidden_states = residual + self.resid_mlp_dropout(feed_forward_hidden_states)

		outputs = (hidden_states,)
		if output_attentions:
			outputs += (self_attn_weights,)
		return outputs
@register_module(
	"base-module",
	config=Phi3Config,
	model_type="phi3",
	embedding_layer_names=["embed_tokens"],
)
class Phi3Model(EasyDeLBaseModule):
	"""Phi-3 decoder-only transformer backbone.

	Structure: token embedding, then ``num_hidden_layers`` ``Phi3DecoderLayer``
	blocks, then a final ``RMSNorm``. No language-model head is included.

	Args:
		config: Model configuration.
		dtype: Computation dtype.
		param_dtype: Storage dtype of the parameters.
		precision: Matmul precision forwarded to sub-modules.
		rngs: Random number streams for initialization and dropout.
	"""

	def __init__(
		self,
		config: Phi3Config,
		dtype: jnp.dtype = jnp.float32,
		param_dtype: jnp.dtype = jnp.float32,
		precision: jax.lax.PrecisionLike = None,
		*,
		rngs: nn.Rngs,
	):
		super().__init__(
			config=config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)
		self.padding_idx = config.pad_token_id
		self.vocab_size = config.vocab_size

		self.embed_tokens = nn.Embed(
			config.vocab_size,
			config.hidden_size,
			dtype=dtype,
			param_dtype=param_dtype,
			rngs=rngs,
		)
		# NOTE(review): ``embed_dropout`` is constructed but never applied in
		# ``__call__``; it is kept so the module structure stays unchanged.
		self.embed_dropout = nn.Dropout(config.embd_pdrop)
		self.layers = [
			Phi3DecoderLayer(
				config=config,
				dtype=dtype,
				param_dtype=param_dtype,
				precision=precision,
				rngs=rngs,
			)
			for _ in range(self.config.num_hidden_layers)
		]
		# Pass ``rngs`` like every other RMSNorm in this file does.
		self.norm = RMSNorm(
			config.hidden_size,
			eps=config.rms_norm_eps,
			dtype=dtype,
			param_dtype=param_dtype,
			rngs=rngs,
		)

	@functools.cached_property
	def frequencies(self):
		"""Rotary frequencies shared by all layers, computed once per instance."""
		return self.config.get_basic_frequencies(
			head_size=self.config.hidden_size // self.config.num_attention_heads,
			base=self.config.rope_theta,
		)

	def __call__(
		self,
		input_ids: tp.Optional[chex.Array] = None,
		inputs_embeds: tp.Optional[chex.Array] = None,
		attention_mask: tp.Optional[chex.Array] = None,
		position_ids: tp.Optional[chex.Array] = None,
		segment_ids: tp.Optional[chex.Array] = None,
		output_attentions: tp.Optional[bool] = None,
		output_hidden_states: tp.Optional[bool] = None,
		past_key_values: tp.Optional[TransformerCache] = None,
		return_dict: bool = True,
	) -> tp.Union[FlaxBaseModelOutput, tp.Tuple]:
		"""Run the Phi-3 backbone.

		Args:
			input_ids (tp.Optional[chex.Array]): Token IDs. Exactly one of
				``input_ids`` / ``inputs_embeds`` must be given.
			inputs_embeds (tp.Optional[chex.Array]): Precomputed embeddings of shape
				(batch, sequence, hidden).
			attention_mask (tp.Optional[chex.Array]): Mask over tokens. Defaults to
				all-ones. A non-boolean mask is converted with ``== 1``. A 2-D mask is
				expanded to 4-D.
			position_ids (tp.Optional[chex.Array]): Token positions. If omitted, they
				are derived from the cumulative sum of ``attention_mask``.
			segment_ids (tp.Optional[chex.Array]): Optional segment IDs.
			output_attentions (tp.Optional[bool]): If True, collect per-layer
				attention weights.
			output_hidden_states (tp.Optional[bool]): If True, collect the hidden
				states entering each layer plus the final normed output.
			past_key_values (tp.Optional[TransformerCache]): Key/value cache. An empty
				cache with one view per layer is created when omitted.
			return_dict (bool): If True, return a ``FlaxBaseModelOutput``. Otherwise
				return a tuple with ``None`` entries removed.

		Returns:
			FlaxBaseModelOutput | tp.Tuple: The model outputs.

		Raises:
			ValueError: If both or neither of ``input_ids`` and ``inputs_embeds``
				are provided.
			AssertionError: If the sequence is longer than
				``config.max_position_embeddings``.
		"""
		# Exactly one of the two inputs must be supplied.
		if (input_ids is None) == (inputs_embeds is None):
			raise ValueError(
				"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
			)
		if inputs_embeds is None:
			inputs_embeds = self.embed_tokens(input_ids.astype("i4"))
		batch_size, sequence_length, _ = inputs_embeds.shape

		all_attentions = () if output_attentions else None
		all_hidden_states = () if output_hidden_states else None
		assert sequence_length <= self.config.max_position_embeddings, (
			f"Maximum Position Embedding Reached ! (Expected <= {self.config.max_position_embeddings} got {sequence_length})"
		)
		if attention_mask is None:
			attention_mask = jnp.ones((batch_size, sequence_length), "b1")
		elif attention_mask.dtype != jnp.bool:
			attention_mask = jnp.astype(attention_mask == 1, "b1")

		if position_ids is None:
			# Each attended token gets the count of attended tokens before it.
			# Positions before the first attended token (cumsum - 1 == -1) are
			# clamped to 0. ``jnp.maximum`` avoids the deprecated
			# ``jnp.clip(a_min=...)`` keyword.
			position_ids = jnp.broadcast_to(
				jnp.maximum(jnp.cumsum(attention_mask, axis=-1) - 1, 0),
				(batch_size, sequence_length),
			).astype(jnp.int32)

		if attention_mask.ndim == 2:
			attention_mask = jnp.expand_dims(attention_mask, (1, 2))

		if past_key_values is None:
			past_key_values = TransformerCache.init_empty(len(self.layers))

		hidden_states = inputs_embeds
		for idx, block in enumerate(self.layers):
			if output_hidden_states:
				all_hidden_states += (hidden_states,)
			layer_outputs = block(
				hidden_states=hidden_states,
				attention_mask=attention_mask,
				position_ids=position_ids,
				cache_view=past_key_values.views[idx],
				causal_mask=self.causal_mask,
				output_attentions=output_attentions,
				segment_ids=segment_ids,
				frequencies=self.frequencies,
			)
			hidden_states = layer_outputs[0]

			if output_attentions:
				all_attentions += (layer_outputs[1],)

		hidden_states = self.norm(hidden_states)

		if output_hidden_states:
			all_hidden_states += (hidden_states,)

		outputs = (hidden_states, all_hidden_states, all_attentions, past_key_values)
		if not return_dict:
			return tuple(v for v in outputs if v is not None)

		return FlaxBaseModelOutput(
			last_hidden_state=hidden_states,
			hidden_states=all_hidden_states,
			attentions=all_attentions,
			past_key_values=past_key_values,
		)
@register_module(
	"causal-language-model",
	config=Phi3Config,
	model_type="phi3",
	embedding_layer_names=["embed_tokens"],
)
class Phi3ForCausalLM(EasyDeLBaseModule):
	"""Phi-3 backbone with a language-modeling head.

	Logits come from ``lm_head``. When ``config.tie_word_embeddings`` is set,
	they come instead from a matmul against the transposed token-embedding
	table.

	Args:
		config: Model configuration.
		dtype: Computation dtype.
		param_dtype: Storage dtype of the parameters.
		precision: Matmul precision forwarded to sub-modules.
		rngs: Random number streams for initialization and dropout.
	"""

	def __init__(
		self,
		config: Phi3Config,
		dtype: jnp.dtype = jnp.float32,
		param_dtype: jnp.dtype = jnp.float32,
		precision: jax.lax.PrecisionLike = None,
		*,
		rngs: nn.Rngs,
	):
		super().__init__(
			config=config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)
		self.model = Phi3Model(
			config=config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)
		self.vocab_size = self.config.vocab_size
		self.lm_head = nn.Linear(
			config.hidden_size,
			config.vocab_size,
			use_bias=False,
			kernel_init=jax.nn.initializers.normal(config.initializer_range),
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)

	def __call__(
		self,
		input_ids: tp.Optional[chex.Array] = None,
		inputs_embeds: tp.Optional[chex.Array] = None,
		attention_mask: tp.Optional[chex.Array] = None,
		position_ids: tp.Optional[chex.Array] = None,
		segment_ids: tp.Optional[chex.Array] = None,
		output_attentions: tp.Optional[bool] = None,
		output_hidden_states: tp.Optional[bool] = None,
		past_key_values: tp.Optional[TransformerCache] = None,
		return_dict: bool = True,
	) -> tp.Union[FlaxCausalLMOutput, tp.Tuple]:
		"""Run the backbone and project the final hidden states to vocabulary logits.

		Args:
			input_ids (tp.Optional[chex.Array]): Token IDs. Exactly one of
				``input_ids`` / ``inputs_embeds`` must be given.
			inputs_embeds (tp.Optional[chex.Array]): Precomputed input embeddings.
			attention_mask (tp.Optional[chex.Array]): Mask over tokens.
			position_ids (tp.Optional[chex.Array]): Token positions.
			segment_ids (tp.Optional[chex.Array]): Optional segment IDs.
			output_attentions (tp.Optional[bool]): If True, return per-layer
				attention weights.
			output_hidden_states (tp.Optional[bool]): If True, return per-layer
				hidden states.
			past_key_values (tp.Optional[TransformerCache]): Key/value cache.
			return_dict (bool): If True, return a ``FlaxCausalLMOutput``. Otherwise
				return ``(logits, *rest)``, where ``rest`` holds the backbone's
				non-``None`` hidden states, attentions and cache, in the same order
				as the ``FlaxCausalLMOutput`` fields.

		Returns:
			FlaxCausalLMOutput | tp.Tuple: The model outputs.
		"""
		outputs = self.model(
			input_ids=input_ids,
			attention_mask=attention_mask,
			position_ids=position_ids,
			past_key_values=past_key_values,
			output_attentions=output_attentions,
			output_hidden_states=output_hidden_states,
			return_dict=True,
			inputs_embeds=inputs_embeds,
			segment_ids=segment_ids,
		)

		hidden_states = outputs.last_hidden_state

		if self.config.tie_word_embeddings:
			# Contract the last hidden axis with axis 0 of the (hidden, vocab)
			# transposed embedding table; no batch dimensions.
			lm_logits = jax.lax.dot_general(
				hidden_states,
				self.model.embed_tokens.embedding.value.T,
				(((hidden_states.ndim - 1,), (0,)), ((), ())),
			)
		else:
			lm_logits = self.lm_head(hidden_states)

		if not return_dict:
			# Skip the backbone's last_hidden_state (index 0). It is superseded by
			# the logits, and dropping it keeps the tuple aligned with the
			# FlaxCausalLMOutput fields.
			return (lm_logits,) + outputs[1:]

		return FlaxCausalLMOutput(
			logits=lm_logits,
			hidden_states=outputs.hidden_states,
			attentions=outputs.attentions,
			past_key_values=outputs.past_key_values,
		)