vSurge Inference Engine Example
The vSurge component in EasyDeL is an inference engine for large language models that supports both streaming and non-streaming text generation. This example, based on a typical inference script, shows how to use vSurge with the oDrive engine backend.
Import necessary modules
First, import the required libraries: jax, jax.numpy, transformers.AutoTokenizer, and easydel.
```python
import jax
from jax import numpy as jnp
from transformers import AutoTokenizer

import easydel as ed
```
Load the model and tokenizer
Load the pretrained model with EasyDeL's AutoEasyDeLModelForCausalLM and the tokenizer with Hugging Face's AutoTokenizer, configuring sharding, data types, and the attention mechanism along the way.
- pretrained_model_name_or_path: The model to load, given as a Hugging Face Hub ID or a local path.
- dtype and param_dtype: The data types for computation and for stored model parameters, respectively. jnp.bfloat16 is commonly used because it halves memory use relative to float32.
- max_length: The maximum sequence length (prompt plus generated tokens) the model is configured to handle.
- prefill_length: The maximum prompt length processed in the prefill stage, the initial forward pass over the prompt.
- partition_axis: An ed.PartitionAxis instance describing how tensors are partitioned across devices; the defaults are used here.
- processor: The tokenizer loaded with AutoTokenizer. For batched inference, the padding side is set to "left" and the pad token to the EOS token.
- model: The EasyDeL model loaded with AutoEasyDeLModelForCausalLM, configured with:
  - auto_shard_model: Automatically shards the model parameters across the available devices.
  - sharding_axis_dims: The sizes of the device-mesh axes used for sharding. A -1 entry takes all remaining devices, so (1, 1, -1, 1) places every device on the third axis.
  - config_kwargs: Additional settings for the model's underlying configuration object, passed as an ed.EasyDeLBaseConfigDict:
    - freq_max_position_embeddings and mask_max_position_embeddings: The maximum positions covered by the rotary-embedding frequencies and the attention mask; both are set to max_length.
    - kv_cache_quantization_method: The quantization method for the KV cache. ed.EasyDeLQuantizationMethods.NONE disables quantization.
    - gradient_checkpointing: The gradient checkpointing policy. Checkpointing only matters for training, so ed.EasyDeLGradientCheckPointers.NONE disables it for inference.
    - attn_mechanism: The attention implementation. ed.AttentionMechanisms.PAGED_ATTENTION is essential here because vSurge manages the KV cache in pages (see page_size below).
  - quantization_method: The quantization method for the model weights. ed.EasyDeLQuantizationMethods.NONE disables quantization.
  - precision: The precision of matrix-multiplication computations. jax.lax.Precision.DEFAULT selects the backend's default (fastest) precision.
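The processor entry above sets the padding side to left. A minimal pure-Python sketch of why this matters for batched generation: with left padding, each row's last real token lands in the final column, so a decoder can read every row's next-token logits from the same position. The left_pad helper is hypothetical and only illustrates the idea; it is not part of EasyDeL or transformers.

```python
def left_pad(batch, pad_id):
    """Left-pad token-ID lists to a common width (hypothetical helper).

    Returns (input_ids, attention_mask); the mask is 1 for real tokens
    and 0 for padding.
    """
    width = max(len(seq) for seq in batch)
    input_ids = [[pad_id] * (width - len(seq)) + list(seq) for seq in batch]
    attention_mask = [[0] * (width - len(seq)) + [1] * len(seq) for seq in batch]
    return input_ids, attention_mask


# Two prompts of different lengths; 0 stands in for the pad (EOS) token ID.
ids, mask = left_pad([[11, 12, 13], [21]], pad_id=0)
print(ids)   # [[11, 12, 13], [0, 0, 21]]
print(mask)  # [[1, 1, 1], [0, 0, 1]]
# The last column holds each prompt's final real token.
print([row[-1] for row in ids])  # [13, 21]
```

With right padding, the second row would end in pad tokens, and the position of its last real token would differ from the first row's.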
```python
pretrained_model_name_or_path = "Qwen/Qwen3-8B"
dtype = param_dtype = jnp.bfloat16
max_length = 8192
prefill_length = 4096
partition_axis = ed.PartitionAxis()

# Left padding with EOS as the pad token for batched inference.
processor = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
processor.padding_side = "left"
processor.pad_token_id = processor.eos_token_id

model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path,
    auto_shard_model=True,
    sharding_axis_dims=(1, 1, -1, 1),
    config_kwargs=ed.EasyDeLBaseConfigDict(
        freq_max_position_embeddings=max_length,
        mask_max_position_embeddings=max_length,
        kv_cache_quantization_method=ed.EasyDeLQuantizationMethods.NONE,
        gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NONE,
        # Paged attention matches the paged KV cache that vSurge manages.
        attn_mechanism=ed.AttentionMechanisms.PAGED_ATTENTION,
    ),
    quantization_method=ed.EasyDeLQuantizationMethods.NONE,
    param_dtype=param_dtype,
    dtype=dtype,
    partition_axis=partition_axis,
    precision=jax.lax.Precision.DEFAULT,
)
```
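The sharding_axis_dims argument in the call above behaves, as assumed here, like a reshape of the device list: a single -1 entry is inferred from the device count. The sketch mimics that resolution in plain Python. resolve_axis_dims is a hypothetical helper, not an EasyDeL function, and EasyDeL's real mesh construction may differ in details such as device ordering.

```python
import math


def resolve_axis_dims(axis_dims, num_devices):
    """Replace a single -1 so the product of axis sizes equals num_devices.

    Hypothetical helper mirroring reshape-style -1 inference; it does not
    call EasyDeL or JAX.
    """
    if list(axis_dims).count(-1) > 1:
        raise ValueError("At most one axis may be -1")
    known = math.prod(d for d in axis_dims if d != -1)
    if num_devices % known != 0:
        raise ValueError(f"{num_devices} devices cannot be split by {known}")
    resolved = tuple(num_devices // known if d == -1 else d for d in axis_dims)
    if math.prod(resolved) != num_devices:
        raise ValueError(f"{axis_dims} does not cover {num_devices} devices")
    return resolved


# On an 8-device host, (1, 1, -1, 1) puts all 8 devices on the third axis.
print(resolve_axis_dims((1, 1, -1, 1), 8))  # (1, 1, 8, 1)
print(resolve_axis_dims((2, 1, -1, 1), 8))  # (2, 1, 4, 1)
```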
Create the vSurge instance
Instantiate the engine with the vSurge.create_odriver class method, which sets vSurge up to use the oDrive engine backend. Pass the loaded model and processor together with the engine configuration:
- model: The loaded EasyDeL model.
- processor: The loaded tokenizer.
- max_prefill_length: The longest prompt, in tokens, that the prefill stage accepts.
- prefill_lengths: A list of prompt lengths the engine should be prepared to handle efficiently; this example passes a single entry, [prefill_length].
- page_size: The number of tokens stored in each KV cache page, the unit in which paged attention allocates cache memory.
- hbm_utilization: The fraction of High Bandwidth Memory (HBM) the engine targets for the KV cache (0.875, or 87.5%, below).
- max_concurrent_prefill: The maximum number of prefill requests processed concurrently.
- max_concurrent_decodes: The maximum number of decode requests (token generation) processed concurrently.
- seed: A random seed for reproducible sampling.
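One way to read prefill_lengths, assuming vSurge treats it as a set of padding buckets the way JetStream-style serving engines do, is that each prompt is padded up to the smallest listed length that fits. The sketch shows that selection rule only. pick_prefill_bucket is a hypothetical helper, and whether vSurge selects buckets exactly this way is an assumption.

```python
import bisect


def pick_prefill_bucket(prompt_len, prefill_lengths, max_prefill_length):
    """Return the smallest bucket >= prompt_len (hypothetical helper)."""
    if prompt_len > max_prefill_length:
        raise ValueError(
            f"prompt of {prompt_len} tokens exceeds max_prefill_length={max_prefill_length}"
        )
    buckets = sorted(prefill_lengths)
    idx = bisect.bisect_left(buckets, prompt_len)  # first bucket >= prompt_len
    if idx == len(buckets):
        raise ValueError(f"no bucket can hold {prompt_len} tokens")
    return buckets[idx]


# With several buckets, short prompts pad to a short length...
print(pick_prefill_bucket(100, [256, 1024, 4096], 4096))   # 256
print(pick_prefill_bucket(1024, [256, 1024, 4096], 4096))  # 1024
# ...while prefill_lengths=[prefill_length] pads every prompt to 4096.
print(pick_prefill_bucket(100, [4096], 4096))  # 4096
```

Under this reading, fewer buckets mean fewer shapes to prepare but more padding per short prompt; more buckets trade extra preparation work for less wasted prefill computation.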
```python
page_size = 128
hbm_utilization = 0.875
max_concurrent_decodes = 64
max_concurrent_prefill = 64  # Often set equal to max_concurrent_decodes

surge = ed.vSurge.create_odriver(
    model=model,
    processor=processor,
    max_prefill_length=prefill_length,
    prefill_lengths=[prefill_length],
    page_size=page_size,
    hbm_utilization=hbm_utilization,
    max_concurrent_prefill=max_concurrent_prefill,
    max_concurrent_decodes=max_concurrent_decodes,
    seed=877,
)
```
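To get a feel for the page_size and hbm_utilization values just set, the sketch estimates the size of one KV cache page and how many pages a memory budget holds. It assumes a standard layout (separate K and V tensors for every layer, bfloat16 at 2 bytes per element) and takes Qwen3-8B's attention shape (36 layers, 8 key-value heads, head dimension 128) as inputs; verify these against the model's config.json. How vSurge actually applies hbm_utilization (to total memory or to memory left after the weights) is not stated here, so pages_that_fit and the 32 GiB budget are illustrative assumptions.

```python
import math


def kv_bytes_per_page(num_layers, num_kv_heads, head_dim, page_size, bytes_per_elem=2):
    """Bytes for one KV page: K and V for page_size tokens across all layers."""
    return 2 * num_layers * num_kv_heads * head_dim * page_size * bytes_per_elem


def pages_per_sequence(seq_len, page_size):
    """Pages needed to hold seq_len tokens (the last page may be partly empty)."""
    return math.ceil(seq_len / page_size)


def pages_that_fit(free_bytes, hbm_utilization, bytes_per_page):
    """Hypothetical estimate of how many pages fit in the reserved budget."""
    return int(free_bytes * hbm_utilization) // bytes_per_page


# Assumed Qwen3-8B attention shape (check config.json).
page_bytes = kv_bytes_per_page(num_layers=36, num_kv_heads=8, head_dim=128, page_size=128)
print(page_bytes / 2**20)                              # 18.0 (MiB per page)
print(pages_per_sequence(8192, 128))                   # 64 pages for one max_length sequence
print(pages_that_fit(32 * 2**30, 0.875, page_bytes))   # 1592 pages in a 32 GiB budget
```

Under these assumptions, 1592 pages hold only 24 full-length sequences (1592 // 64), fewer than max_concurrent_decodes=64; shorter sequences occupy fewer pages, which is what paging exploits.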
Start and Compile the Engine
Before running inference, start and compile the engine. start() initializes the engine, and compile() compiles the JAX functions it needs for efficient execution, so compilation happens up front rather than during the first requests.
```python
surge.start()
surge.compile()
```
Non-Streaming Generation
For non-streaming generation, call generate with stream=False, passing a list of prompts and a matching list of sampling parameters. generate is a coroutine, so await it; it returns a list of final results once generation has finished for all prompts.
- prompts: A list of input strings to generate text from.
- sampling_params: A list of ed.SamplingParams objects, one per prompt, with fields such as:
  - max_tokens: The maximum number of tokens to generate.
  - temperature: Controls the randomness of the output; higher values produce more varied text.
  - top_p: The cumulative probability threshold for nucleus sampling; only the most likely tokens whose probabilities together reach top_p are considered.
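To make temperature and top_p concrete, the sketch implements the textbook technique on a short list of logits: divide by the temperature, apply softmax, keep the smallest set of most probable tokens whose cumulative probability reaches top_p, and sample from that set. It is a generic plain-Python illustration, not EasyDeL's sampler, whose edge-case handling (for example temperature=0) may differ.

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax; lower temperature sharpens the distribution."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def nucleus(probs, top_p):
    """Indices of the smallest high-probability set whose mass reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    return kept


def sample_token(logits, temperature=1.0, top_p=1.0, rng=random):
    """Sample one token index with temperature scaling and nucleus (top-p) filtering."""
    probs = softmax(logits, temperature)
    kept = nucleus(probs, top_p)
    # Renormalize over the kept tokens by drawing within their total mass.
    threshold = rng.random() * sum(probs[i] for i in kept)
    running = 0.0
    for i in kept:
        running += probs[i]
        if threshold < running:
            return i
    return kept[-1]  # guard against floating-point round-off


logits = [1.0, 3.0, 2.0]
print([round(p, 3) for p in softmax(logits, temperature=1.0)])
print([round(p, 3) for p in softmax(logits, temperature=0.1)])  # nearly all mass on index 1
print(nucleus([0.25, 0.25, 0.25, 0.25], top_p=0.5))  # [0, 1]
print(sample_token(logits, temperature=0.7, top_p=0.01))  # 1: only the top token survives
```

This is why the first request in the example (temperature=0.1) behaves almost greedily, while the second (temperature=0.6, top_p=0.9) samples from a wider but still truncated set of tokens.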
```python
import asyncio

non_streaming_prompts = [
    "USER:What is the capital of France?\nASSISTANT:",
    "USER:Explain the concept of recursion\nASSISTANT:",
]
non_streaming_sampling_params = [
    ed.SamplingParams(max_tokens=30, temperature=0.1),
    ed.SamplingParams(max_tokens=80, temperature=0.6, top_p=0.9),
]


async def run_non_streaming():
    # generate is a coroutine; with stream=False it resolves to a list of final results.
    final_results = await surge.generate(
        prompts=non_streaming_prompts,
        sampling_params=non_streaming_sampling_params,
        stream=False,
    )
    # final_results is a list of ReturnSample objects (one per prompt).
    for i, result in enumerate(final_results):
        print(f"Non-Streaming Result for Prompt {i + 1}:")
        print(f"  Generated Text: {result.text}")
        print(f"  Tokens per second: {result.tokens_per_second}")


# To run this in a script:
# asyncio.run(run_non_streaming())
```
Iterate through the results to access the generated text and other information like tokens per second.
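The examples use asyncio.run in a script. Where an event loop is already running, such as in a Jupyter notebook, asyncio.run raises a RuntimeError, and you should await the coroutine directly instead (for example, await run_non_streaming() in a notebook cell). The hypothetical run_sync helper below sketches a guard for that case using only the standard library.

```python
import asyncio


def run_sync(coro):
    """Run a coroutine to completion from synchronous code (hypothetical helper).

    asyncio.run cannot be called while an event loop is already running,
    so this fails early with a clearer message in that situation.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running: safe to start one.
        return asyncio.run(coro)
    coro.close()  # avoid a "coroutine was never awaited" warning
    raise RuntimeError("An event loop is already running; await the coroutine instead.")


async def demo():
    await asyncio.sleep(0)
    return "done"


print(run_sync(demo()))  # done
```

In a script, run_sync(run_non_streaming()) behaves like asyncio.run(run_non_streaming()).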
Streaming Generation
vSurge also supports streaming generation, which is useful for applications that display tokens as they are produced (for example, chatbots). Call generate with stream=True; because generate is a coroutine, awaiting it returns an asynchronous iterator that produces, at each step, a list with one ReturnSample per prompt.
Iterate over this asynchronous iterator with async for to process the incoming tokens.
```python
streaming_prompts = [
    "USER:Tell me a short story about a cat.\nASSISTANT:",
    "USER:Describe the process of photosynthesis.\nASSISTANT:",
]
streaming_sampling_params = [
    ed.SamplingParams(max_tokens=50, temperature=0.7),
    ed.SamplingParams(max_tokens=100, temperature=0.5),
]


async def run_streaming():
    # With stream=True, awaiting generate returns an async iterator.
    async for request_output in await surge.generate(
        prompts=streaming_prompts,
        sampling_params=streaming_sampling_params,
        stream=True,
    ):
        # request_output is a list of ReturnSample objects, one per prompt.
        for i, sample in enumerate(request_output):
            # In streaming mode, text holds the cumulative generated text so far.
            print(f"Streaming Update for Prompt {i + 1}: {sample.text}")
        # Optionally add a small delay to simulate processing time:
        # await asyncio.sleep(0.01)


# To run this in a script:
# asyncio.run(run_streaming())
```
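Per the comment in run_streaming, each streamed sample.text holds the cumulative text so far, so printing it directly repeats earlier output. A common remedy is to remember the previous text per prompt and emit only the new suffix. The sketch does this against a stand-in async generator so it runs without EasyDeL; with vSurge, you would pass the async iterator obtained from generate with stream=True instead. stream_deltas and fake_stream are hypothetical, and the cumulative-text behavior is taken from the example's comment rather than verified against the library.

```python
import asyncio
from types import SimpleNamespace


async def fake_stream():
    """Stand-in for the streaming output: one list of samples (one per prompt) per step."""
    snapshots = [
        ["Once", "Plants"],
        ["Once upon", "Plants use"],
        ["Once upon a time", "Plants use light"],
    ]
    for step in snapshots:
        await asyncio.sleep(0)
        yield [SimpleNamespace(text=t) for t in step]


async def stream_deltas(stream):
    """Turn cumulative per-prompt text into (prompt_index, new_text) pairs."""
    previous = {}
    async for request_output in stream:
        for i, sample in enumerate(request_output):
            seen = previous.get(i, "")
            delta = sample.text[len(seen):]  # assumes earlier text is never rewritten
            previous[i] = sample.text
            if delta:
                yield i, delta


async def main():
    pieces = {}
    async for i, delta in stream_deltas(fake_stream()):
        pieces.setdefault(i, []).append(delta)
        print(f"Prompt {i + 1}: {delta!r}")
    return pieces


pieces = asyncio.run(main())
```

Slicing by the previous length assumes that detokenization never alters text already emitted; if it can, compare against the previous string before slicing.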
Stop the Engine
After inference is complete, stop the vSurge engine to release its resources.
```python
surge.stop()
```
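If generation raises an exception, a surge.stop() call at the end of a script never runs. A small context manager guarantees shutdown. The sketch uses contextlib.contextmanager and assumes only the start(), compile(), and stop() methods shown in this example. serving is a hypothetical helper name, and FakeEngine stands in for a real vSurge instance so the snippet runs on its own.

```python
from contextlib import contextmanager


@contextmanager
def serving(engine):
    """Start and compile an engine, and always stop it on exit (hypothetical helper)."""
    engine.start()
    try:
        engine.compile()
        yield engine
    finally:
        engine.stop()


class FakeEngine:
    """Records lifecycle calls; use the real surge object in practice."""

    def __init__(self):
        self.calls = []

    def start(self):
        self.calls.append("start")

    def compile(self):
        self.calls.append("compile")

    def stop(self):
        self.calls.append("stop")


engine = FakeEngine()
try:
    with serving(engine):
        raise RuntimeError("simulated failure during generation")
except RuntimeError:
    pass
print(engine.calls)  # ['start', 'compile', 'stop']
```

With a real engine, this pattern would look like `with serving(surge): asyncio.run(run_non_streaming())`. Here stop() runs even if compilation or generation fails.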