easydel.inference.utils
- class easydel.inference.utils.SampleState(current_length: Union[Array, NamedSharding], sequences: Union[Array, NamedSharding], running_token: Union[Array, NamedSharding], is_sequence_finished: Union[Array, NamedSharding], prng_key: Union[PRNGKey, NamedSharding], model_kwargs: Union[Dict[str, Array], NamedSharding], generate_func_flops: Optional[float] = -inf, interval_func_flops: Optional[float] = -inf, tokens_pre_second: Optional[float] = -inf, generated_tokens: Optional[int] = 0, padded_length: Optional[int] = 0)
Bases: Mapping
Data class representing the state of the sampling process.
- current_length: Union[Array, NamedSharding]
- from_tuple()
- generate_func_flops: Optional[float] = -inf
- generated_tokens: Optional[int] = 0
- interval_func_flops: Optional[float] = -inf
- is_sequence_finished: Union[Array, NamedSharding]
- items(): a set-like view of the state's (key, value) pairs
- keys(): a set-like view of the state's keys
- model_kwargs: Union[Dict[str, Array], NamedSharding]
- padded_length: Optional[int] = 0
- prng_key: Union[PRNGKey, NamedSharding]
- replace(**kwargs)
- running_token: Union[Array, NamedSharding]
- sequences: Union[Array, NamedSharding]
- to_tuple()
- tokens_pre_second: Optional[float] = -inf
- values(): a view of the state's values
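The member list mixes dataclass fields with Mapping methods (keys(), items(), values()) plus replace(), to_tuple() and from_tuple(). A minimal sketch of how such a dataclass-plus-Mapping container can behave is shown here; MiniSampleState is a hypothetical stand-in carrying only a few of SampleState's fields, and the method bodies are assumptions about their semantics, not EasyDeL's source.

```python
import dataclasses
from collections.abc import Mapping
from typing import Any

import jax.numpy as jnp


@dataclasses.dataclass
class MiniSampleState(Mapping):
    """Hypothetical stand-in for SampleState (not EasyDeL's implementation)."""

    current_length: Any
    sequences: Any
    is_sequence_finished: Any
    generated_tokens: int = 0

    def _names(self):
        return tuple(f.name for f in dataclasses.fields(self))

    # Mapping protocol: field names act as keys.
    def __getitem__(self, key):
        if key not in self._names():
            raise KeyError(key)
        return getattr(self, key)

    def __iter__(self):
        return iter(self._names())

    def __len__(self):
        return len(self._names())

    # Assumed semantics: return a copy with some fields swapped out.
    def replace(self, **kwargs):
        return dataclasses.replace(self, **kwargs)

    # Assumed semantics: round-trip through a positional tuple.
    def to_tuple(self):
        return tuple(getattr(self, name) for name in self._names())

    @classmethod
    def from_tuple(cls, values):
        return cls(*values)


state = MiniSampleState(
    current_length=jnp.array(3),
    sequences=jnp.zeros((2, 8), dtype=jnp.int32),
    is_sequence_finished=jnp.zeros((2,), dtype=bool),
)
print(list(state.keys()))
# ['current_length', 'sequences', 'is_sequence_finished', 'generated_tokens']
next_state = state.replace(generated_tokens=1)  # state itself is unchanged
```

Returning a fresh object from replace() rather than mutating in place fits JAX's functional style, where a decoding loop threads the state from one step to the next.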
- easydel.inference.utils.compile_function(func, func_input_args, func_input_kwargs, mesh=None, in_shardings=None, out_shardings=None, static_argnums=None, donate_argnums=None)
Compiles a JAX function with optional sharding and mesh configuration.
- Parameters
func – The JAX function to compile.
func_input_args – Input arguments for the function.
func_input_kwargs – Input keyword arguments for the function.
mesh – Optional JAX mesh for distributed execution.
in_shardings – Optional input sharding specifications.
out_shardings – Optional output sharding specifications.
static_argnums – Indices of static arguments.
donate_argnums – Indices of arguments to donate.
- Returns
Compiled JAX function.
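The parameter list matches what jax.jit accepts for ahead-of-time (AOT) compilation, so one plausible reading is that compile_function jits func, lowers it with the example inputs, and compiles the result. The sketch below, compile_sketch, is a hypothetical helper written against the public JAX AOT API (jax.jit(...).lower(...).compile()); it is not EasyDeL's source. It leaves out the mesh argument because a NamedSharding already carries its mesh.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def compile_sketch(func, func_input_args, func_input_kwargs,
                   in_shardings=None, out_shardings=None,
                   static_argnums=None, donate_argnums=None):
    """Hypothetical AOT helper: jit -> lower -> compile (not EasyDeL's code)."""
    jit_kwargs = {}
    # Forward shardings only when given, so JAX infers placements otherwise.
    if in_shardings is not None:
        jit_kwargs["in_shardings"] = in_shardings
    if out_shardings is not None:
        jit_kwargs["out_shardings"] = out_shardings
    jitted = jax.jit(func, static_argnums=static_argnums,
                     donate_argnums=donate_argnums, **jit_kwargs)
    # The example inputs fix the shapes, dtypes and shardings of the executable.
    lowered = jitted.lower(*func_input_args, **func_input_kwargs)
    return lowered.compile()


# One-device mesh so the example also runs on a plain CPU install.
mesh = Mesh(np.array(jax.devices()[:1]), ("dp",))
sharding = NamedSharding(mesh, P("dp"))


def double(x):
    return x * 2


x = jax.device_put(jnp.arange(4.0), sharding)
compiled = compile_sketch(double, (x,), {}, in_shardings=(sharding,),
                          out_shardings=sharding)
result = compiled(x)
```

Because the executable is specialized to the example inputs, calling it with an array of a different shape raises an error instead of retracing.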
- easydel.inference.utils.create_sampling_step(logits_processor: FlaxLogitsProcessorList, logits_warper: FlaxLogitsProcessorList, eos_token_id: Array, pad_token_id: Array, do_sample: bool = True)
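create_sampling_step is listed without a description. Judging only from its parameters, it presumably builds a per-token step that post-processes the logits, then either samples or takes the argmax, and finally handles EOS and padding. The sketch below shows that general pattern under a hypothetical name (sampling_step_sketch). It skips the logits_processor and logits_warper stages (in the Hugging Face Flax API these are called as processor(input_ids, scores, cur_len)) and is not EasyDeL's implementation.

```python
import jax
import jax.numpy as jnp


def sampling_step_sketch(logits, is_finished, key, eos_token_id,
                         pad_token_id, do_sample=True):
    """Hypothetical single decoding step (not EasyDeL's create_sampling_step).

    logits: [batch, vocab] scores for the next token.
    is_finished: [batch] bool flags for rows that already emitted EOS.
    """
    if do_sample:
        # Draw one token per row from softmax(logits).
        next_token = jax.random.categorical(key, logits, axis=-1)
    else:
        # Greedy decoding: pick the highest-scoring token.
        next_token = jnp.argmax(logits, axis=-1)
    # Rows that are already finished keep emitting padding.
    next_token = jnp.where(is_finished, pad_token_id, next_token)
    # A row becomes finished once it emits EOS.
    is_finished = is_finished | (next_token == eos_token_id)
    return next_token, is_finished


logits = jnp.array([[0.1, 2.0, 0.3], [5.0, 0.0, 0.0], [0.0, 0.0, 9.0]])
finished = jnp.array([False, False, True])
tokens, finished = sampling_step_sketch(
    logits, finished, jax.random.PRNGKey(0),
    eos_token_id=0, pad_token_id=2, do_sample=False,
)
```

With do_sample=False, row 0 picks token 1, row 1 picks the EOS token 0 and is marked finished, and row 2 was already finished so it receives the pad token.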
- easydel.inference.utils.lower_function(func, func_input_args, func_input_kwargs, mesh=None, in_shardings=None, out_shardings=None, static_argnums=None, donate_argnums=None)
Lowers a JAX function with optional sharding and mesh configuration.
- Parameters
func – The JAX function to lower.
func_input_args – Input arguments for the function.
func_input_kwargs – Input keyword arguments for the function.
mesh – Optional JAX mesh for distributed execution.
in_shardings – Optional input sharding specifications.
out_shardings – Optional output sharding specifications.
static_argnums – Indices of static arguments.
donate_argnums – Indices of arguments to donate.
- Returns
Lowered JAX function.
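Lowering stops one step short of compilation: JAX traces the function and emits a StableHLO module that can be inspected now and compiled later. A minimal sketch of that split using the public JAX AOT API, assuming lower_function follows the same jit-then-lower pattern; affine and the *_spec names are hypothetical, and jax.ShapeDtypeStruct stands in for real inputs because lowering only needs shapes and dtypes.

```python
import jax
import jax.numpy as jnp


def affine(x, w, b):
    return x @ w + b


# Abstract inputs: only shape and dtype matter for lowering.
x_spec = jax.ShapeDtypeStruct((2, 3), jnp.float32)
w_spec = jax.ShapeDtypeStruct((3, 4), jnp.float32)
b_spec = jax.ShapeDtypeStruct((4,), jnp.float32)

lowered = jax.jit(affine).lower(x_spec, w_spec, b_spec)
hlo_text = lowered.as_text()  # StableHLO module as a string

# Compile later, then run on concrete arrays with matching shapes and dtypes.
compiled = lowered.compile()
out = compiled(jnp.ones((2, 3)), jnp.ones((3, 4)), jnp.zeros((4,)))
```

Keeping the lowered object around is useful when you want to inspect the IR or choose when the (possibly slow) compilation happens.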
- class easydel.inference.utils.vInferenceConfig(max_new_tokens: int = 64, min_length: Optional[int] = None, streaming_chunks: int = 16, temperature: float = 0.0, top_p: float = 0.95, top_k: int = 50, do_sample: bool = True, no_repeat_ngram_size: Optional[int] = None, num_return_sequences: Union[int, Dict[int, int], NoneType] = 1, suppress_tokens: Optional[list] = None, forced_bos_token_id: Optional[int] = None, forced_eos_token_id: Optional[int] = None, pad_token_id: Optional[int] = None, bos_token_id: Optional[int] = None, eos_token_id: Union[int, List[int], NoneType] = None, partition_rules: Optional[Tuple[Tuple[str, Any]]] = None, partition_axis: Optional[eformer.escale.partition.constraints.PartitionAxis] = None, _loop_rows: Optional[int] = None)
Bases: object
- bos_token_id: Optional[int] = None
- do_sample: bool = True
- eos_token_id: Optional[Union[int, List[int]]] = None
- forced_bos_token_id: Optional[int] = None
- forced_eos_token_id: Optional[int] = None
- max_new_tokens: int = 64
- min_length: Optional[int] = None
- no_repeat_ngram_size: Optional[int] = None
- num_return_sequences: Optional[Union[int, Dict[int, int]]] = 1
- pad_token_id: Optional[int] = None
- partition_axis: Optional[PartitionAxis] = None
- partition_rules: Optional[Tuple[Tuple[str, Any]]] = None
- streaming_chunks: int = 16
- suppress_tokens: Optional[list] = None
- temperature: float = 0.0
- top_k: int = 50
- top_p: float = 0.95
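temperature, top_k and top_p are the usual sampling knobs. As a rough illustration of what they control, here is standard temperature scaling followed by top-k and nucleus (top-p) filtering in JAX. warp_logits_sketch is hypothetical, and this page does not say whether vInferenceConfig applies the filters in this order or treats temperature=0.0 as greedy decoding; the sketch simply skips scaling when temperature is not positive.

```python
import jax
import jax.numpy as jnp


def warp_logits_sketch(logits, temperature, top_k, top_p):
    """Hypothetical temperature + top-k + top-p filter (not EasyDeL's code)."""
    # Temperature scaling; skipped when temperature <= 0 (assumption).
    if temperature > 0:
        logits = logits / temperature
    # Top-k: mask everything below the k-th largest logit.
    if top_k > 0:
        k = min(top_k, logits.shape[-1])
        kth = jax.lax.top_k(logits, k)[0][..., -1:]
        logits = jnp.where(logits < kth, -jnp.inf, logits)
    # Top-p: keep the smallest set of tokens whose probability mass reaches top_p.
    if top_p < 1.0:
        sorted_logits = jnp.sort(logits, axis=-1)[..., ::-1]  # descending
        probs = jax.nn.softmax(sorted_logits, axis=-1)
        mass_before = jnp.cumsum(probs, axis=-1) - probs
        keep = mass_before < top_p  # the top token is always kept
        threshold = jnp.min(jnp.where(keep, sorted_logits, jnp.inf),
                            axis=-1, keepdims=True)
        logits = jnp.where(logits < threshold, -jnp.inf, logits)
    return logits


logits = jnp.log(jnp.array([0.5, 0.3, 0.15, 0.05]))
out = warp_logits_sketch(logits, temperature=1.0, top_k=3, top_p=0.7)
```

Here top_k=3 drops the 0.05 token; the remaining probabilities renormalize to about 0.53, 0.32 and 0.16, and top_p=0.7 then keeps only the first two, so only indices 0 and 1 stay finite.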