easydel.inference.vinference._fn
Module implementing the text-generation pipeline with JAX/Flax.
- easydel.inference.vinference._fn.basic_generation_first_iter_fn(graphdef: object, graphstate: dict, graphother, state: SampleState, generation_config: vInferenceConfig) -> SampleState
Compiled function that performs the initial generation step.
This function takes the model's graph definition (graphdef), its parameter state (graphstate), the remaining non-parameter state (graphother), the current sampling state, and the generation configuration as input. It initializes the generation state and performs the first sampling step.
- Returns
The initial generation state after the first sampling step.
- Return type
SampleState
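The first step can be pictured with a toy, pure-Python sketch. It is not EasyDeL's implementation: `ToyState`, `toy_logits`, `first_iter`, and greedy argmax sampling are hypothetical stand-ins for `SampleState`, the model forward pass, and whatever sampler the generation configuration selects. It shows the shape of the work: the whole prompt is processed once, one token is sampled per sequence from the logits, and the returned state records that token, the new length, and which rows hit the end-of-sequence token. The documented function does this compiled with JAX on batched arrays rather than Python lists.

```python
from dataclasses import dataclass, field
from typing import Callable, List

@dataclass
class ToyState:
    # Hypothetical, simplified stand-in for SampleState (not EasyDeL's class).
    sequences: List[List[int]]            # prompt tokens (+ generated tokens) per batch row
    current_length: int                   # tokens written so far in every row
    running_token: List[int] = field(default_factory=list)
    is_sequence_finished: List[bool] = field(default_factory=list)

def toy_logits(tokens: List[int], vocab_size: int = 5) -> List[float]:
    # Hypothetical "model": puts all mass on (last token + 1) mod vocab_size.
    nxt = (tokens[-1] + 1) % vocab_size
    return [1.0 if i == nxt else 0.0 for i in range(vocab_size)]

def first_iter(state: ToyState, eos_token_id: int,
               logits_fn: Callable[[List[int]], List[float]] = toy_logits) -> ToyState:
    new_tokens = []
    for seq in state.sequences:
        logits = logits_fn(seq)  # one forward pass over the full prompt
        # Greedy sampling: pick the index of the largest logit.
        new_tokens.append(max(range(len(logits)), key=logits.__getitem__))
    return ToyState(
        sequences=[seq + [tok] for seq, tok in zip(state.sequences, new_tokens)],
        current_length=state.current_length + 1,
        running_token=new_tokens,
        is_sequence_finished=[tok == eos_token_id for tok in new_tokens],
    )

state = first_iter(ToyState(sequences=[[0, 1], [3, 4]], current_length=2), eos_token_id=0)
print(state.running_token)         # [2, 0]
print(state.is_sequence_finished)  # [False, True]
```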
- easydel.inference.vinference._fn.basic_generation_iter_fn(graphdef: object, graphstate: dict, graphother, state: SampleState, generation_config: vInferenceConfig, loop_max_tokens: int) -> SampleState
Compiled function that performs one interval of generation steps.
This function takes the model's graphdef, graphstate, and graphother, the current generation state, the generation configuration, and the maximum number of tokens to generate in this loop (loop_max_tokens) as input. It continues generation from the given state until the termination condition is met, producing at most loop_max_tokens new tokens per call.
- Returns
The updated generation state after the interval generation steps.
- Return type
SampleState
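The interval loop can be sketched in pure Python. This toy mirrors the `cond`/`body` split that `jax.lax.while_loop(cond_fun, body_fun, init_val)` expects in JAX. `ToyState`, `step`, `interval_iter`, the `max_length` cap, and the exact stopping rule (all rows finished, the length cap reached, or `loop_max_tokens` tokens produced in this call) are assumptions for illustration, not EasyDeL's code.

```python
from dataclasses import dataclass
from typing import List

@dataclass
class ToyState:
    # Hypothetical, simplified stand-in for SampleState.
    sequences: List[List[int]]
    current_length: int
    is_sequence_finished: List[bool]

def step(state: ToyState, eos_token_id: int, vocab_size: int = 5) -> ToyState:
    # Toy decoding step: next token is (last token + 1) mod vocab_size.
    # Rows that already finished are padded with eos_token_id.
    seqs, done = [], []
    for seq, finished in zip(state.sequences, state.is_sequence_finished):
        tok = eos_token_id if finished else (seq[-1] + 1) % vocab_size
        seqs.append(seq + [tok])
        done.append(finished or tok == eos_token_id)
    return ToyState(seqs, state.current_length + 1, done)

def interval_iter(state: ToyState, loop_max_tokens: int, max_length: int,
                  eos_token_id: int) -> ToyState:
    start = state.current_length

    def cond(s: ToyState) -> bool:
        # Keep going while some row is unfinished and both budgets remain.
        return (not all(s.is_sequence_finished)
                and s.current_length < max_length
                and s.current_length - start < loop_max_tokens)

    while cond(state):  # in JAX: jax.lax.while_loop(cond, body, state)
        state = step(state, eos_token_id)
    return state

state = ToyState(sequences=[[1], [3]], current_length=1, is_sequence_finished=[False, False])
state = interval_iter(state, loop_max_tokens=2, max_length=10, eos_token_id=0)
print(state.sequences)  # [[1, 2, 3], [3, 4, 0]]
state = interval_iter(state, loop_max_tokens=2, max_length=10, eos_token_id=0)
print(state.sequences)  # [[1, 2, 3, 4, 0], [3, 4, 0, 0, 0]]
```

Capping each call at `loop_max_tokens` hands control back to the caller between intervals, which is what allows partial output to be streamed before the next call resumes from the returned state.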
- easydel.inference.vinference._fn.get_compiled_funcs(batch_size: int, input_tokens_length: int, id: str, safe: bool = True, fn1: Optional[Callable] = None, fn2: Optional[Callable] = None)
Retrieves compiled generation functions from a cache.
- Parameters
batch_size – The batch size.
input_tokens_length – The length of the input tokens.
id – A unique identifier for the compilation.
- Returns
A tuple containing the compiled generate and interval generate functions, or (None, None) if no entry is found in the cache.
- Return type
Tuple[Optional[Callable], Optional[Callable]]
- easydel.inference.vinference._fn.put_compiled_funcs(compiled_generate_func, compiled_interval_func, batch_size, input_tokens_length, id)
Stores compiled generation functions in a cache.
- Parameters
compiled_generate_func – The compiled generate function.
compiled_interval_func – The compiled interval generate function.
batch_size – The batch size.
input_tokens_length – The length of the input tokens.
id – A unique identifier for the compilation.