easydel.inference.utils#

class easydel.inference.utils.SampleState(current_length: Union[Array, NamedSharding], sequences: Union[Array, NamedSharding], running_token: Union[Array, NamedSharding], is_sequence_finished: Union[Array, NamedSharding], prng_key: Union[PRNGKey, NamedSharding], model_kwargs: Union[Dict[str, Array], NamedSharding], generate_func_flops: Optional[float] = -inf, interval_func_flops: Optional[float] = -inf, tokens_pre_second: Optional[float] = -inf, generated_tokens: Optional[int] = 0, padded_length: Optional[int] = 0)[source]#

Bases: Mapping

Data class representing the state of the sampling process.

current_length: Union[Array, NamedSharding]#
from_tuple()#
generate_func_flops: Optional[float] = -inf#
generated_tokens: Optional[int] = 0#
interval_func_flops: Optional[float] = -inf#
is_sequence_finished: Union[Array, NamedSharding]#
items() → a set-like object providing a view on D's items#
keys() → a set-like object providing a view on D's keys#
model_kwargs: Union[Dict[str, Array], NamedSharding]#
padded_length: Optional[int] = 0#
prng_key: Union[PRNGKey, NamedSharding]#
replace(**kwargs)#
running_token: Union[Array, NamedSharding]#
sequences: Union[Array, NamedSharding]#
to_tuple()#
tokens_pre_second: Optional[float] = -inf#
values() → an object providing a view on D's values#
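
SampleState is listed as a Mapping subclass that also exposes replace(), to_tuple() and from_tuple(). The sketch below illustrates that pattern with the standard library only. MiniState is a hypothetical stand-in with two of the fields above, not the EasyDeL class, and the from_tuple signature used here (a classmethod taking the tuple) is an assumption, since this page does not show its arguments.

```python
from collections.abc import Mapping
from dataclasses import dataclass, fields, replace as dc_replace


@dataclass(frozen=True)
class MiniState(Mapping):
    """Hypothetical, reduced stand-in for SampleState (illustration only)."""

    current_length: int
    generated_tokens: int = 0

    def _names(self):
        # Field names in declaration order; these act as the mapping keys.
        return tuple(f.name for f in fields(self))

    # Mapping protocol: keys(), items() and values() are derived from these.
    def __getitem__(self, key):
        if key not in self._names():
            raise KeyError(key)
        return getattr(self, key)

    def __iter__(self):
        return iter(self._names())

    def __len__(self):
        return len(self._names())

    def replace(self, **kwargs):
        # Return a new state with some fields changed; the original is untouched.
        return dc_replace(self, **kwargs)

    def to_tuple(self):
        return tuple(self[name] for name in self)

    @classmethod
    def from_tuple(cls, values):
        # Assumed signature: rebuild a state from values in field order.
        return cls(*values)


state = MiniState(current_length=4)
stepped = state.replace(current_length=5, generated_tokens=1)
restored = MiniState.from_tuple(stepped.to_tuple())
```

Returning a new object from replace() rather than mutating in place suits JAX-style loops, where the state is passed through functional transformations.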
easydel.inference.utils.compile_function(func, func_input_args, func_input_kwargs, mesh=None, in_shardings=None, out_shardings=None, static_argnums=None, donate_argnums=None)[source]#

Compiles a JAX function with optional sharding and mesh configuration.

Parameters
  • func – The JAX function to compile.

  • func_input_args – Input arguments for the function.

  • func_input_kwargs – Input keyword arguments for the function.

  • mesh – Optional JAX mesh for distributed execution.

  • in_shardings – Optional input sharding specifications.

  • out_shardings – Optional output sharding specifications.

  • static_argnums – Indices of static arguments.

  • donate_argnums – Indices of arguments to donate.

Returns

Compiled JAX function.
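
As a rough illustration of what compiling ahead of time from example inputs looks like in plain JAX, the sketch below chains jax.jit, .lower() and .compile(). compile_sketch is a hypothetical helper written for this example, not the EasyDeL implementation, and it leaves out mesh, in_shardings and out_shardings for brevity.

```python
import jax
import jax.numpy as jnp


def compile_sketch(func, args, kwargs, static_argnums=None, donate_argnums=None):
    """Hypothetical helper: jit, then lower for concrete inputs, then compile."""
    jitted = jax.jit(
        func,
        static_argnums=static_argnums,
        donate_argnums=donate_argnums,
    )
    # Lowering traces the function for the shapes and dtypes of the inputs.
    return jitted.lower(*args, **kwargs).compile()


def axpy(x, y):
    return 2.0 * x + y


x = jnp.ones((4,), dtype=jnp.float32)
y = jnp.arange(4, dtype=jnp.float32)

compiled = compile_sketch(axpy, (x, y), {})
# The compiled executable only accepts inputs with the same shapes and dtypes.
result = compiled(x, y)
```

Compiling up front moves tracing and XLA compilation out of the first real call, which matters when latency of the first generation step is measured.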

easydel.inference.utils.create_sampling_step(logits_processor: FlaxLogitsProcessorList, logits_warper: FlaxLogitsProcessorList, eos_token_id: Array, pad_token_id: Array, do_sample: bool = True)[source]#
easydel.inference.utils.lower_function(func, func_input_args, func_input_kwargs, mesh=None, in_shardings=None, out_shardings=None, static_argnums=None, donate_argnums=None)[source]#

Lowers a JAX function with optional sharding and mesh configuration.

Parameters
  • func – The JAX function to lower.

  • func_input_args – Input arguments for the function.

  • func_input_kwargs – Input keyword arguments for the function.

  • mesh – Optional JAX mesh for distributed execution.

  • in_shardings – Optional input sharding specifications.

  • out_shardings – Optional output sharding specifications.

  • static_argnums – Indices of static arguments.

  • donate_argnums – Indices of arguments to donate.

Returns

Lowered JAX function.
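
For orientation, lowering in plain JAX stops one step before compilation and yields an object whose intermediate representation can be inspected. The sketch below uses only jax.jit and its .lower() method with a static argument. repeat_add is a hypothetical function chosen for the example, and nothing here is taken from the EasyDeL source.

```python
import jax
import jax.numpy as jnp


def repeat_add(x, times):
    # `times` drives a Python loop, so it has to be static at trace time.
    for _ in range(times):
        x = x + 1.0
    return x


x = jnp.zeros((2,), dtype=jnp.float32)

# Mark argument 1 as static; its value is baked into the lowered program.
lowered = jax.jit(repeat_add, static_argnums=(1,)).lower(x, 3)

# Text form of the lowered module, useful for inspection.
ir_text = lowered.as_text()

# A lowered function can be compiled later; static arguments are not passed again.
compiled = lowered.compile()
out = compiled(x)
```

Keeping lowering separate from compilation makes it possible to inspect or cache the lowered program before paying the compilation cost.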

class easydel.inference.utils.vInferenceConfig(max_new_tokens: int = 64, min_length: Optional[int] = None, streaming_chunks: int = 16, temperature: float = 0.0, top_p: float = 0.95, top_k: int = 50, do_sample: bool = True, no_repeat_ngram_size: Optional[int] = None, num_return_sequences: Union[int, Dict[int, int], NoneType] = 1, suppress_tokens: Optional[list] = None, forced_bos_token_id: Optional[int] = None, forced_eos_token_id: Optional[int] = None, pad_token_id: Optional[int] = None, bos_token_id: Optional[int] = None, eos_token_id: Union[int, List[int], NoneType] = None, partition_rules: Optional[Tuple[Tuple[str, Any]]] = None, partition_axis: Optional[eformer.escale.partition.constraints.PartitionAxis] = None, _loop_rows: Optional[int] = None)[source]#

Bases: object

bos_token_id: Optional[int] = None#
do_sample: bool = True#
eos_token_id: Optional[Union[int, List[int]]] = None#
forced_bos_token_id: Optional[int] = None#
forced_eos_token_id: Optional[int] = None#
get_logits_processor()[source]#
get_logits_warper()[source]#
get_partition_rules(runtime_config: Optional[Tuple[int, int]] = None)[source]#
max_new_tokens: int = 64#
min_length: Optional[int] = None#
no_repeat_ngram_size: Optional[int] = None#
num_return_sequences: Optional[Union[int, Dict[int, int]]] = 1#
pad_token_id: Optional[int] = None#
partition_axis: Optional[PartitionAxis] = None#
partition_rules: Optional[Tuple[Tuple[str, Any]]] = None#
streaming_chunks: int = 16#
suppress_tokens: Optional[list] = None#
temperature: float = 0.0#
top_k: int = 50#
top_p: float = 0.95#
tree_flatten()[source]#
classmethod tree_unflatten(aux, children)[source]#