# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import typing as tp
import chex
import jax
import jax.experimental
import jax.experimental.pallas
import jax.random
from jax import numpy as jnp
from jax import random, sharding
from eformer.jaximus import implicit
from eformer.escale import PartitionAxis
from jax.sharding import PartitionSpec
from flax import nnx as nn
from .logits_process import (
FlaxForcedBOSTokenLogitsProcessor,
FlaxForcedEOSTokenLogitsProcessor,
FlaxLogitsProcessorList,
FlaxMinLengthLogitsProcessor,
FlaxNoRepeatNGramLogitsProcessor,
FlaxSuppressTokensLogitsProcessor,
FlaxTemperatureLogitsWarper,
FlaxTopKLogitsWarper,
FlaxTopPLogitsWarper,
hash_fn,
)
@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class vInferenceConfig:
    """
    Generation settings used to build logits processors/warpers and sharding rules.

    The class is registered as a JAX pytree: every field (including the derived
    ``_loop_rows``) is exposed as a child by ``tree_flatten`` and handed back
    positionally to the constructor by ``tree_unflatten``.

    Attributes:
        max_new_tokens: Used together with ``streaming_chunks`` to derive
            ``_loop_rows``.
        min_length: When set (and ``> -1``) together with ``eos_token_id``, a
            ``FlaxMinLengthLogitsProcessor`` is added.
        streaming_chunks: Chunk size used to derive ``_loop_rows``.
        temperature: Passed to ``FlaxTemperatureLogitsWarper`` unless it is
            ``None`` or ``1.0``.
        top_p: Passed to ``FlaxTopPLogitsWarper`` when it is below ``1.0``.
        top_k: Passed to ``FlaxTopKLogitsWarper`` when it is non-zero.
        do_sample: Stored only; not read by any method of this class.
        no_repeat_ngram_size: When positive, a
            ``FlaxNoRepeatNGramLogitsProcessor`` is added.
        num_return_sequences: Stored only; not read by any method of this class.
        suppress_tokens: When set, a ``FlaxSuppressTokensLogitsProcessor`` is added.
        forced_bos_token_id: When set, a ``FlaxForcedBOSTokenLogitsProcessor``
            is added.
        forced_eos_token_id: When set, a ``FlaxForcedEOSTokenLogitsProcessor``
            is added.
        pad_token_id: Stored only; not read by any method of this class.
        bos_token_id: Stored only; not read by any method of this class.
        eos_token_id: A single id or a list of ids; only the first list entry
            is given to the min-length processor.
        partition_rules: Explicit rules returned verbatim by
            ``get_partition_rules`` when provided.
        partition_axis: Axis names used to build the default partition rules.
        _loop_rows: ``ceil(max_new_tokens / streaming_chunks)``, recomputed in
            ``__post_init__`` whenever ``max_new_tokens`` is a Python ``int``.
    """

    max_new_tokens: int = 64
    min_length: tp.Optional[int] = None
    streaming_chunks: int = 16
    temperature: float = 0.0
    top_p: float = 0.95
    top_k: int = 50
    do_sample: bool = True
    no_repeat_ngram_size: tp.Optional[int] = None
    num_return_sequences: tp.Optional[tp.Union[int, tp.Dict[int, int]]] = 1
    suppress_tokens: tp.Optional[list] = None
    forced_bos_token_id: tp.Optional[int] = None
    forced_eos_token_id: tp.Optional[int] = None
    pad_token_id: tp.Optional[int] = None
    bos_token_id: tp.Optional[int] = None
    eos_token_id: tp.Optional[tp.Union[int, tp.List[int]]] = None
    partition_rules: tp.Optional[tp.Tuple[tp.Tuple[str, tp.Any]]] = None
    partition_axis: tp.Optional[PartitionAxis] = None
    _loop_rows: tp.Optional[int] = None

    def tree_flatten(self):
        """
        Flatten the config for JAX pytree handling.

        Returns:
            A ``(children, aux_data)`` pair. ``children`` holds every field in
            declaration order; ``aux_data`` is ``None`` because nothing is
            treated as static metadata.
        """
        children = (
            self.max_new_tokens,
            self.min_length,
            self.streaming_chunks,
            self.temperature,
            self.top_p,
            self.top_k,
            self.do_sample,
            self.no_repeat_ngram_size,
            self.num_return_sequences,
            self.suppress_tokens,
            self.forced_bos_token_id,
            self.forced_eos_token_id,
            self.pad_token_id,
            self.bos_token_id,
            self.eos_token_id,
            self.partition_rules,
            self.partition_axis,
            self._loop_rows,
        )
        # JAX documents aux data as hashable (it is stored in the treedef); a
        # dict is not, so use ``None`` instead. ``tree_unflatten`` ignores it.
        return children, None

    @classmethod
    def tree_unflatten(cls, aux, children):
        """
        Rebuild a config from the output of ``tree_flatten``.

        Args:
            aux: Auxiliary data from ``tree_flatten`` (unused).
            children: Field values in declaration order.

        Returns:
            A new ``vInferenceConfig`` constructed positionally from ``children``.
        """
        return cls(*children)

    def get_partition_rules(
        self,
        runtime_config: tp.Optional[
            tp.Tuple[int, int]
        ] = None,  # in case that someone needs to customize this
    ):
        """
        Return the regex -> ``PartitionSpec`` rules used for state sharding.

        Args:
            runtime_config: Accepted for customization by callers/subclasses;
                not read by this implementation.

        Returns:
            ``self.partition_rules`` when it was provided, otherwise a default
            tuple of ``(pattern, PartitionSpec)`` pairs built from
            ``self.partition_axis``.

        Raises:
            AssertionError: If neither ``partition_rules`` nor
                ``partition_axis`` is set.
        """
        if self.partition_rules is not None:
            return self.partition_rules
        assert self.partition_axis is not None, (
            "partition axis is required for state sharding"
        )
        paxis = self.partition_axis
        kvps = PartitionSpec(
            paxis.batch_axis,
            paxis.key_sequence_axis,
            paxis.head_axis,
            paxis.attention_dim_axis,
        )
        idps = PartitionSpec(paxis.batch_axis, paxis.sequence_axis)
        return (
            ("(sequences|running_token)", idps),
            ("model_kwargs/(attention_mask|position_ids)", idps),
            # A8BIT
            ("model_kwargs/past_key_values/views/[0-9]+/(key|value)/(scale|weight)", kvps),
            # NF4
            ("model_kwargs/past_key_values/views/[0-9]+/(key|value)/(packed|absmax)", kvps),
            ("model_kwargs/past_key_values/views/[0-9]+/(key|value)", kvps),
            # Catch-all: anything not matched above gets an empty PartitionSpec.
            (".*", PartitionSpec()),
        )

    def __post_init__(self):
        """Derive ``_loop_rows`` as the ceiling of ``max_new_tokens / streaming_chunks``."""
        # Only compute for plain Python ints; after a pytree round-trip the
        # field may hold a non-int value, in which case the passed-in
        # ``_loop_rows`` is kept as-is.
        if isinstance(self.max_new_tokens, int):
            self._loop_rows = (
                self.max_new_tokens + self.streaming_chunks - 1
            ) // self.streaming_chunks

    def __repr__(self):
        """
        Render the public (non-underscore) attributes, one per line.

        Values whose rendered line reaches 500 characters are abbreviated to
        ``ClassName(...)``.
        """
        string = f"{self.__class__.__name__}(\n"
        for k, v in self.__dict__.items():
            if not k.startswith("_"):
                try:
                    repr_src = f" {k} : " + v.__str__().replace("\n", "\n ") + "\n"
                    string += (
                        repr_src
                        if len(repr_src) < 500
                        else f" {k} : " + f"{v.__class__.__name__}(...)" + "\n"
                    )
                except TypeError:
                    # ``v.__str__()`` raises TypeError when ``v`` is itself a
                    # class object; such attributes are skipped.
                    pass
        return string.strip() + "\n)"

    __str__ = __repr__
    __hash__ = hash_fn

    def get_logits_warper(self):
        """
        Build the list of logits warpers implied by this config.

        Returns:
            A ``FlaxLogitsProcessorList`` holding the temperature, top-k and
            top-p warpers that are enabled, or ``None`` when none apply.
        """
        warpers = FlaxLogitsProcessorList()
        if self.temperature is not None and self.temperature != 1.0:
            warpers.append(FlaxTemperatureLogitsWarper(self.temperature))
        if self.top_k is not None and self.top_k != 0:
            warpers.append(FlaxTopKLogitsWarper(top_k=self.top_k, min_tokens_to_keep=1))
        if self.top_p is not None and self.top_p < 1.0:
            warpers.append(FlaxTopPLogitsWarper(top_p=self.top_p, min_tokens_to_keep=1))
        if len(warpers) == 0:
            return None
        return warpers

    def get_logits_processor(self):
        """
        Build the list of logits processors implied by this config.

        Returns:
            A ``FlaxLogitsProcessorList`` holding the enabled min-length,
            forced-BOS, forced-EOS, suppress-tokens and no-repeat-ngram
            processors, or ``None`` when none apply.
        """
        processors = FlaxLogitsProcessorList()
        # The min-length processor takes a single EOS id, so use the first one
        # when a list was configured.
        eos_id = (
            self.eos_token_id[0] if isinstance(self.eos_token_id, list) else self.eos_token_id
        )
        if (
            self.min_length is not None
            and self.eos_token_id is not None
            and self.min_length > -1
        ):
            processors.append(FlaxMinLengthLogitsProcessor(self.min_length, eos_id))
        if self.forced_bos_token_id is not None:
            processors.append(FlaxForcedBOSTokenLogitsProcessor(self.forced_bos_token_id))
        if self.forced_eos_token_id is not None:
            # The dataclass declares no ``max_length`` field, so a direct
            # ``self.max_length`` lookup raised AttributeError whenever a forced
            # EOS id was configured. Honour a dynamically attached
            # ``max_length`` if one exists, otherwise fall back to
            # ``max_new_tokens``.
            # NOTE(review): ``max_new_tokens`` does not include the prompt
            # length - confirm whether callers should attach an absolute
            # ``max_length`` instead.
            max_length = getattr(self, "max_length", None)
            if max_length is None:
                max_length = self.max_new_tokens
            fet = FlaxForcedEOSTokenLogitsProcessor(max_length, self.forced_eos_token_id)
            processors.append(fet)
        if self.suppress_tokens is not None:
            processors.append(FlaxSuppressTokensLogitsProcessor(self.suppress_tokens))
        if self.no_repeat_ngram_size is not None and self.no_repeat_ngram_size > 0:
            processors.append(FlaxNoRepeatNGramLogitsProcessor(self.no_repeat_ngram_size))
        if len(processors) == 0:
            return None
        return processors
def lower_function(
    func,
    func_input_args,
    func_input_kwargs,
    mesh=None,
    in_shardings=None,
    out_shardings=None,
    static_argnums=None,
    donate_argnums=None,
):
    """
    Wrap ``func`` in ``jax.jit`` and lower it for the given example inputs.

    Args:
        func: The function to jit and lower.
        func_input_args: Positional arguments forwarded to ``.lower``.
        func_input_kwargs: Keyword arguments forwarded to ``.lower``.
        mesh: Optional mesh; when given, jitting and lowering run inside
            ``with mesh:``.
        in_shardings: Forwarded to ``jax.jit``.
        out_shardings: Forwarded to ``jax.jit``.
        static_argnums: Forwarded to ``jax.jit``.
        donate_argnums: Forwarded to ``jax.jit``.

    Returns:
        The lowered (not yet compiled) computation returned by ``.lower``.
    """

    def _jit_then_lower():
        # Single place where the jit wrapper is built and lowered, so the
        # mesh / no-mesh paths cannot drift apart.
        jitted = jax.jit(
            func,
            in_shardings=in_shardings,
            out_shardings=out_shardings,
            static_argnums=static_argnums,
            donate_argnums=donate_argnums,
        )
        return jitted.lower(*func_input_args, **func_input_kwargs)

    if mesh is not None:
        with mesh:
            return _jit_then_lower()
    return _jit_then_lower()
def compile_function(
    func,
    func_input_args,
    func_input_kwargs,
    mesh=None,
    in_shardings=None,
    out_shardings=None,
    static_argnums=None,
    donate_argnums=None,
):
    """
    Lower ``func`` via ``lower_function`` and compile the result.

    Args:
        func: The function to jit, lower and compile.
        func_input_args: Positional example inputs used for lowering.
        func_input_kwargs: Keyword example inputs used for lowering.
        mesh: Optional mesh forwarded to ``lower_function``.
        in_shardings: Optional input sharding specification.
        out_shardings: Optional output sharding specification.
        static_argnums: Indices of static arguments.
        donate_argnums: Indices of arguments to donate.

    Returns:
        The object returned by calling ``.compile()`` on the lowered function.
    """
    lowered = lower_function(
        func,
        func_input_args,
        func_input_kwargs,
        mesh=mesh,
        in_shardings=in_shardings,
        out_shardings=out_shardings,
        static_argnums=static_argnums,
        donate_argnums=donate_argnums,
    )
    return lowered.compile()
@chex.dataclass
class SampleState:
    """
    Container for everything the sampling loop carries from one step to the next.

    The first six fields are required; the remaining bookkeeping fields have
    defaults.
    """

    current_length: tp.Union[jax.Array, sharding.NamedSharding]
    sequences: tp.Union[jax.Array, sharding.NamedSharding]
    running_token: tp.Union[jax.Array, sharding.NamedSharding]
    is_sequence_finished: tp.Union[jax.Array, sharding.NamedSharding]
    prng_key: tp.Union[random.PRNGKey, sharding.NamedSharding]
    model_kwargs: tp.Union[tp.Dict[str, jax.Array], sharding.NamedSharding]

    # vInference Ops
    generate_func_flops: tp.Optional[float] = float("-inf")
    interval_func_flops: tp.Optional[float] = float("-inf")
    tokens_pre_second: tp.Optional[float] = float("-inf")
    generated_tokens: tp.Optional[int] = 0
    padded_length: tp.Optional[int] = 0

    def __repr__(self):
        """
        Render the public (non-underscore) attributes, one per line.

        Returns:
            A multi-line string; any attribute whose rendered line reaches 500
            characters is abbreviated to ``ClassName(...)``.
        """
        header = f"{self.__class__.__name__}(\n"
        rendered = []
        for name, value in self.__dict__.items():
            if name.startswith("_"):
                continue
            try:
                # ``value.__str__()`` (not ``str(value)``) is deliberate: it
                # raises TypeError for class objects, which are then skipped.
                entry = f" {name} : " + value.__str__().replace("\n", "\n ") + "\n"
                if len(entry) >= 500:
                    entry = f" {name} : " + f"{value.__class__.__name__}(...)" + "\n"
            except (TypeError, AttributeError):
                continue
            rendered.append(entry)
        return (header + "".join(rendered)).strip() + "\n)"

    __str__ = __repr__
def create_sampling_step(
    logits_processor: FlaxLogitsProcessorList,
    logits_warper: FlaxLogitsProcessorList,
    eos_token_id: jax.Array,
    pad_token_id: jax.Array,
    do_sample: bool = True,
):
    """
    Build a single-step sampling function closed over the generation settings.

    Args:
        logits_processor: Applied to the last-position logits when not ``None``.
        logits_warper: Applied after the processor when sampling and not ``None``.
        eos_token_id: Token id(s) that mark a sequence as finished.
        pad_token_id: Token written for sequences that are already finished.
        do_sample: Sample with ``jax.random.categorical`` when true, otherwise
            take the ``argmax``.

    Returns:
        The ``sampling_step`` callable (wrapped with ``implicit``).
    """

    @implicit
    def sampling_step(graphdef, graphstate, graphother, state: SampleState):
        """
        Run the model on the running token and advance the state by one token.

        Args:
            graphdef: Graph definition passed to ``nn.merge``.
            graphstate: Graph state passed to ``nn.merge``.
            graphother: Remaining graph state passed to ``nn.merge``.
            state (SampleState): The current generation state.

        Returns:
            SampleState: ``state`` with length, sequences, running token,
            finished flags, PRNG key, model kwargs and token counter updated.
        """
        module = nn.merge(graphdef, graphstate, graphother)
        outputs = module(
            input_ids=state.running_token,
            return_dict=True,
            **state.model_kwargs,
        )

        # Only the logits of the last position decide the next token.
        last_logits = outputs.logits[:, -1]
        if logits_processor is not None:
            last_logits = logits_processor(
                state.sequences,
                last_logits,
                state.current_length,
            )

        if not do_sample:
            chosen = jnp.argmax(last_logits, axis=-1)
        else:
            if logits_warper is not None:
                # The warper receives the logits in the ``input_ids`` slot too.
                last_logits = logits_warper(
                    last_logits,
                    last_logits,
                    state.current_length,
                )
            chosen = jax.random.categorical(state.prng_key, last_logits, axis=-1)

        # Finished sequences emit the pad token instead of the chosen one.
        finished = state.is_sequence_finished
        chosen = chosen * ~finished + pad_token_id * finished
        finished_after = finished | jnp.isin(chosen, eos_token_id)

        chosen_column = chosen[:, None]
        # Write the new column at index ``current_length`` of every sequence.
        updated_sequences = jax.lax.dynamic_update_slice(
            state.sequences,
            chosen_column,
            (0, state.current_length),
        )
        updated_kwargs = module.update_inputs_for_generation(
            outputs,
            state.model_kwargs,
        )
        following_key = jax.random.split(state.prng_key, 2)[0]

        return state.replace(
            current_length=state.current_length + 1,
            sequences=updated_sequences,
            running_token=chosen_column,
            is_sequence_finished=finished_after,
            prng_key=following_key,
            model_kwargs=updated_kwargs,
            generated_tokens=state.generated_tokens + 1,
        )

    return sampling_step