easydel.inference.vinference._fn

Module for the text generation pipeline, built on JAX/Flax.

easydel.inference.vinference._fn.basic_generation_first_iter_fn(graphdef: object, graphstate: dict, graphother, state: SampleState, generation_config: vInferenceConfig) → SampleState

Compiled function for performing the initial generation step.

This function takes the model's graph definition (graphdef), its parameter state (graphstate), the remaining graph state (graphother), the current sampling state (state), and the generation configuration (generation_config) as input. It initializes the generation state and performs the first sampling step.

Returns

The initial generation state after the first sampling step.

Return type

SampleState
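The first step can be pictured as one pass of the model over the prompt followed by sampling a single token. The sketch below is a minimal, self-contained JAX illustration, not EasyDeL's implementation: ToySampleState, toy_model and first_iter are hypothetical names, the "model" is a bare embedding lookup with no attention, and sampling is plain categorical sampling from the logits at the last prompt position.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToySampleState(NamedTuple):
    # Hypothetical stand-in for EasyDeL's SampleState.
    sequences: jax.Array  # (batch, max_len) int32 token buffer
    cur_len: jax.Array    # scalar int32: number of filled positions
    rng: jax.Array        # PRNG key used for sampling


def toy_model(params, input_ids):
    # Hypothetical "model": embed tokens and project back onto the vocabulary.
    # It has no attention, so padding positions cannot affect earlier ones.
    emb = params["embed"][input_ids]   # (batch, seq, hidden)
    return emb @ params["embed"].T     # (batch, seq, vocab)


@jax.jit
def first_iter(params, state: ToySampleState) -> ToySampleState:
    # Run the model over the whole buffer; only positions < cur_len matter.
    logits = toy_model(params, state.sequences)
    # Logits at the last prompt position predict the next token.
    last = jax.lax.dynamic_index_in_dim(
        logits, state.cur_len - 1, axis=1, keepdims=False
    )
    rng, key = jax.random.split(state.rng)
    next_tok = jax.random.categorical(key, last, axis=-1).astype(state.sequences.dtype)
    # Write the sampled token into the buffer at position cur_len.
    sequences = jax.lax.dynamic_update_slice_in_dim(
        state.sequences, next_tok[:, None], state.cur_len, axis=1
    )
    return ToySampleState(sequences, state.cur_len + 1, rng)


# Usage: vocabulary of 16, hidden size 8, two prompts of length 3 in a buffer of 10.
params = {"embed": jax.random.normal(jax.random.PRNGKey(0), (16, 8))}
seqs = jnp.zeros((2, 10), dtype=jnp.int32).at[:, :3].set(
    jnp.array([[1, 2, 3], [4, 5, 6]], dtype=jnp.int32)
)
state = ToySampleState(seqs, jnp.int32(3), jax.random.PRNGKey(1))
new_state = first_iter(params, state)
```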

easydel.inference.vinference._fn.basic_generation_iter_fn(graphdef: object, graphstate: dict, graphother, state: SampleState, generation_config: vInferenceConfig, loop_max_tokens: int) → SampleState

Compiled function for performing interval generation steps.

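This function takes the graph definition (graphdef), the parameter state (graphstate), the remaining graph state (graphother), the current generation state (state), the generation configuration (generation_config), and the maximum number of tokens to generate in this loop (loop_max_tokens) as input. It continues the generation process until the termination condition is met.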

Returns

The updated generation state after the interval generation steps.

Return type

SampleState
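An interval of generation steps can be sketched with jax.lax.while_loop. This is a hypothetical illustration, not EasyDeL's code: ToySampleState, toy_next_token, iter_fn and the EOS value are invented for the example, the "model" simply emits the previous token plus one, and the stopping rule (loop_max_tokens steps, a full buffer, or every sequence finished) is an assumption about what a termination condition might look like.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

VOCAB, EOS = 10, 9  # hypothetical vocabulary size and end-of-sequence id


class ToySampleState(NamedTuple):
    # Hypothetical stand-in for EasyDeL's SampleState.
    sequences: jax.Array  # (batch, max_len) int32 token buffer
    cur_len: jax.Array    # scalar int32: number of filled positions
    done: jax.Array       # (batch,) bool: sequence has emitted EOS


def toy_next_token(sequences, cur_len):
    # Hypothetical "model": the next token is the previous token plus one.
    prev = jax.lax.dynamic_index_in_dim(sequences, cur_len - 1, axis=1, keepdims=False)
    return (prev + 1) % VOCAB


@jax.jit
def iter_fn(state: ToySampleState, loop_max_tokens) -> ToySampleState:
    start = state.cur_len
    max_len = state.sequences.shape[1]

    def cond(s):
        # Assumed stopping rule: loop budget spent, buffer full, or all sequences done.
        return (s.cur_len - start < loop_max_tokens) & (s.cur_len < max_len) & ~jnp.all(s.done)

    def body(s):
        tok = toy_next_token(s.sequences, s.cur_len)
        # Finished sequences keep emitting EOS as padding.
        tok = jnp.where(s.done, EOS, tok)
        seqs = jax.lax.dynamic_update_slice_in_dim(s.sequences, tok[:, None], s.cur_len, axis=1)
        return ToySampleState(seqs, s.cur_len + 1, s.done | (tok == EOS))

    return jax.lax.while_loop(cond, body, state)


# Usage: two sequences with one prompt token each, generating at most 4 tokens.
seqs = jnp.zeros((2, 8), dtype=jnp.int32).at[:, 0].set(jnp.array([0, 6], dtype=jnp.int32))
state = ToySampleState(seqs, jnp.int32(1), jnp.zeros((2,), dtype=bool))
out = iter_fn(state, 4)
```

Keeping the whole interval inside a single while_loop means it runs as one compiled program, so Python dispatch overhead is paid once per interval rather than once per token.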

easydel.inference.vinference._fn.get_compiled_funcs(batch_size: int, input_tokens_length: int, id: str, safe: bool = True, fn1: Optional[Callable] = None, fn2: Optional[Callable] = None)

Retrieves compiled generation functions from a cache.

Parameters
  • batch_size – The batch size.

  • input_tokens_length – The length of the input tokens.

  • id – A unique identifier for the compilation.

Returns

A tuple containing the compiled generate and interval generate functions, or (None, None) if they are not found in the cache.

Return type

Tuple[Callable, Callable]
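A lookup of this kind can be sketched as a dictionary keyed by batch size, prompt length and identifier. The cache variable and get_compiled_funcs_sketch are hypothetical names; the sketch omits safe, fn1 and fn2 because their behavior is not described here, and the real key layout may differ.

```python
from typing import Callable, Dict, Optional, Tuple

# Hypothetical module-level cache: (batch_size, input_tokens_length, id) -> functions.
_COMPILED_CACHE: Dict[Tuple[int, int, str], Tuple[Callable, Callable]] = {}


def get_compiled_funcs_sketch(
    batch_size: int, input_tokens_length: int, id: str
) -> Tuple[Optional[Callable], Optional[Callable]]:
    # Compiled XLA executables are shape-specialized, so batch size and prompt
    # length are part of the key alongside the identifier.
    key = (batch_size, input_tokens_length, id)
    return _COMPILED_CACHE.get(key, (None, None))


# Usage with placeholder functions standing in for compiled executables.
def fake_generate(x):
    return x


def fake_interval(x):
    return x


_COMPILED_CACHE[(1, 128, "demo")] = (fake_generate, fake_interval)
fns = get_compiled_funcs_sketch(1, 128, "demo")   # cache hit
miss = get_compiled_funcs_sketch(2, 128, "demo")  # different batch size: miss
```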

easydel.inference.vinference._fn.measure_flops(func, *args, **kwargs)

easydel.inference.vinference._fn.put_compiled_funcs(compiled_generate_func, compiled_interval_func, batch_size, input_tokens_length, id)

Stores compiled generation functions in a cache.

Parameters
  • compiled_generate_func – The compiled generate function.

  • compiled_interval_func – The compiled interval generate function.

  • batch_size – The batch size.

  • input_tokens_length – The length of the input tokens.

  • id – A unique identifier for the compilation.