Multimodal Inference Serving

EasyDeL can deploy and serve multimodal models through its vInference API. This page explains how to set up and use the vInference API Server for vision-language and audio-language models.

Overview

The vInference API Server in EasyDeL lets you expose multimodal models through a REST API compatible with the OpenAI API format. This makes it easy to integrate these models into applications, websites, and services.
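Because the server speaks the OpenAI chat-completions format, a request body is a `model` name plus a list of `messages`, and a user message's `content` can mix image and text parts. The helper below is a minimal sketch that assembles such a body. The part layout (`{"type": "image", "image": ...}`) is copied from the request examples on this page; `build_chat_request` is a hypothetical helper, not part of EasyDeL.

```python
import json
from typing import Any, Dict, List, Optional


def build_chat_request(
    model: str,
    text: str,
    image: Optional[str] = None,
    stream: bool = False,
    **sampling: Any,
) -> Dict[str, Any]:
    """Assemble an OpenAI-style chat request with an optional image part.

    `image` may be a URL or a base64 string, mirroring the examples on this page.
    Extra keyword arguments (temperature, top_p, max_tokens, ...) are copied as-is.
    """
    content: List[Dict[str, str]] = []
    if image is not None:
        # Image part first, then the text prompt, as in the examples on this page.
        content.append({"type": "image", "image": image})
    content.append({"type": "text", "text": text})
    body: Dict[str, Any] = {
        # Should match the inference_name the server was started with (e.g. "mmprojector").
        "model": model,
        "messages": [{"role": "user", "content": content}],
        "stream": stream,
    }
    body.update(sampling)
    return body


if __name__ == "__main__":
    request = build_chat_request(
        "mmprojector",
        "Describe this image in detail.",
        image="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
        temperature=0.8,
        top_p=0.95,
        max_tokens=1024,
    )
    print(json.dumps(request, indent=2))
```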

Setting Up a Multimodal API Server

Here’s how to set up a multimodal API server with LLaVA as an example:

import easydel as ed
from jax import numpy as jnp
from transformers import AutoProcessor

# Configuration
prefill_length = 2048
max_new_tokens = 1024
model_name = "llava-hf/llava-1.5-7b-hf"

# Load processor and model
processor = AutoProcessor.from_pretrained(model_name)
processor.padding_side = "left"

model = ed.AutoEasyDeLModelForImageTextToText.from_pretrained(
    model_name,
    auto_shard_model=True,
    sharding_axis_dims=(1, 1, -1, 1),
    config_kwargs=ed.EasyDeLBaseConfigDict(
        freq_max_position_embeddings=prefill_length + max_new_tokens,
        mask_max_position_embeddings=prefill_length + max_new_tokens,
        attn_mechanism=ed.AttentionMechanisms.VANILLA,
    ),
    param_dtype=jnp.bfloat16,
    dtype=jnp.float16,
)

# Create vInference instance
inference = ed.vInference(
    model=model,
    processor_class=processor,
    generation_config=ed.vInferenceConfig(
        max_new_tokens=max_new_tokens,
        streaming_chunks=32,
        sampling_params=ed.SamplingParams(
            max_tokens=max_new_tokens,
            temperature=0.8,
            top_p=0.95,
            top_k=10,
        ),
        eos_token_id=model.generation_config.eos_token_id,
    ),
    inference_name="mmprojector",  # Name for the API endpoint
)

# Precompile for the expected input shapes so the first request does not pay JIT compilation cost
inference.precompile(
    ed.vInferencePreCompileConfig(
        batch_size=1,
        prefill_length=prefill_length,
        vision_included=True,  # Enable vision processing
        vision_batch_size=1,
        vision_channels=3,
        vision_height=336,
        vision_width=336,
    )
)

# Start the API server
ed.vInferenceApiServer(inference, max_workers=1).fire(
    host="0.0.0.0",
    port=8000
)
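Loading weights and precompiling happen before `fire()` starts listening, so the port may refuse connections for a while after the script is launched. A deployment script or client can poll until the server is up. The sketch below only checks that a TCP connection to the host and port succeeds; it does not assume any particular health-check endpoint, and `wait_for_server` is a hypothetical helper.

```python
import socket
import time


def wait_for_server(
    host: str = "127.0.0.1",
    port: int = 8000,
    timeout: float = 600.0,
    interval: float = 2.0,
) -> bool:
    """Return True once a TCP connection to host:port succeeds, or False after `timeout` seconds."""
    deadline = time.monotonic() + timeout
    while True:
        try:
            # A successful connect means something is listening on the port.
            with socket.create_connection((host, port), timeout=interval):
                return True
        except OSError:
            # Refused or timed out: give up past the deadline, otherwise retry.
            if time.monotonic() >= deadline:
                return False
            time.sleep(interval)
```

For example, `wait_for_server("127.0.0.1", 8000)` blocks until the API server from the example above is accepting connections, or returns `False` after ten minutes.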

API Request Format

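Requests follow the OpenAI chat-completions format: send the JSON body in a POST request to /v1/chat/completions. The "model" field must match the inference_name given to vInference ("mmprojector" in the setup example). Here's an example request for the LLaVA model: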

{
    "model": "mmprojector",
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
                },
                {
                    "type": "text",
                    "text": "Describe this image in detail."
                }
            ]
        }
    ],
    "temperature": 0.8,
    "top_p": 0.95,
    "max_tokens": 1024,
    "stream": false
}
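To send this body from Python without third-party packages, POST it as JSON to `/v1/chat/completions`. The sketch assumes the non-streaming response has the OpenAI shape `{"choices": [{"message": {"content": ...}}]}`, which is also what the JavaScript client on this page reads; `post_chat_completion` is a hypothetical helper.

```python
import json
import urllib.request
from typing import Any, Dict


def post_chat_completion(
    payload: Dict[str, Any],
    url: str = "http://localhost:8000/v1/chat/completions",
    timeout: float = 300.0,
) -> str:
    """POST a non-streaming chat request and return the first choice's message text.

    Assumes an OpenAI-style response body: {"choices": [{"message": {"content": ...}}]}.
    """
    data = json.dumps(payload).encode("utf-8")
    request = urllib.request.Request(
        url,
        data=data,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    # urlopen raises urllib.error.HTTPError for 4xx/5xx responses.
    with urllib.request.urlopen(request, timeout=timeout) as response:
        body = json.loads(response.read().decode("utf-8"))
    return body["choices"][0]["message"]["content"]
```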

Streaming Responses

For interactive applications, set "stream" to true to receive tokens as they are generated. The rest of the request is unchanged:

{
    "model": "mmprojector",
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
                },
                {
                    "type": "text", 
                    "text": "Describe this image in detail."
                }
            ]
        }
    ],
    "temperature": 0.8,
    "top_p": 0.95,
    "max_tokens": 1024,
    "stream": true
}
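OpenAI-compatible servers typically stream server-sent events: lines of the form `data: {...}` whose `choices[0].delta.content` carries the new text, ending with `data: [DONE]`. This page does not spell out the exact chunk schema vInferenceApiServer emits, so treat that format as an assumption and check it against a real response. `iter_stream_text` is a hypothetical helper that accepts any iterable of lines, such as a `urllib` response object or `requests.Response.iter_lines()`.

```python
import json
from typing import Iterable, Iterator, Union


def iter_stream_text(lines: Iterable[Union[bytes, str]]) -> Iterator[str]:
    """Yield text deltas from OpenAI-style server-sent event lines (format assumed)."""
    for raw in lines:
        line = raw.decode("utf-8") if isinstance(raw, bytes) else raw
        line = line.strip()
        # Skip blank keep-alive lines and anything that is not a data field.
        if not line.startswith("data:"):
            continue
        data = line[len("data:"):].strip()
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        for choice in chunk.get("choices", []):
            # Streaming chunks carry the newly generated text under "delta".
            piece = (choice.get("delta") or {}).get("content")
            if piece:
                yield piece
```

Printing each piece with `print(piece, end="", flush=True)` as it arrives gives the usual typewriter effect in a terminal.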

Serving Multiple Models

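You can serve multiple models from a single vInferenceApiServer by passing a dictionary that maps model names to vInference instances. Each request selects its engine through the "model" field: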

# Set up a LLaVA model (llava_model, llava_processor and llava_config are built as in the single-model example above)
llava_inference = ed.vInference(
    model=llava_model,
    processor_class=llava_processor,
    generation_config=llava_config,
    inference_name="llava",  # Name for the first model
)

# Set up a Gemma 3 model
gemma_inference = ed.vInference(
    model=gemma_model,
    processor_class=gemma_processor,
    generation_config=gemma_config,
    inference_name="gemma3",  # Name for the second model
)

# Serve both models on the same server
ed.vInferenceApiServer(
    {
        "llava": llava_inference,
        "gemma3": gemma_inference 
    },  # Dictionary mapping model names to inference engines
    max_workers=4
).fire(
    host="0.0.0.0",
    port=8000
)

# Alternatively, for a single model with automatic naming:
# ed.vInferenceApiServer(llava_inference).fire(port=8000)
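Conceptually, the dictionary keys form a routing table: the "model" field of each request picks the engine, so clients must use one of the keys ("llava" or "gemma3" above). The sketch below illustrates that lookup with plain Python callables standing in for vInference objects; `route_request` is hypothetical and is not EasyDeL's internal implementation.

```python
from typing import Any, Callable, Dict

# Hypothetical stand-in for an inference engine: maps a request dict to a reply string.
Engine = Callable[[Dict[str, Any]], str]


def route_request(request: Dict[str, Any], engines: Dict[str, Engine]) -> str:
    """Pick the engine named by request["model"] and run it; reject unknown names."""
    name = request.get("model")
    if name not in engines:
        available = ", ".join(sorted(engines))
        raise KeyError(f"unknown model {name!r}; available: {available}")
    return engines[name](request)


if __name__ == "__main__":
    engines = {
        "llava": lambda req: "llava handled it",
        "gemma3": lambda req: "gemma3 handled it",
    }
    print(route_request({"model": "gemma3", "messages": []}, engines))
```

A request whose "model" is not one of the keys has no engine to run, which is why the names in client code must match the dictionary exactly.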

Audio Model API Usage

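For Whisper models, upload an audio file as multipart form data to the /v1/audio/transcriptions endpoint. The model form field names the engine that should handle the request: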

curl -X POST "http://localhost:8000/v1/audio/transcriptions" \
  -H "Content-Type: multipart/form-data" \
  -F "file=@speech.mp3" \
  -F "model=whisper" \
  -F "response_format=json" \
  -F "language=en"

The response will be a JSON object with the transcription:

{
  "text": "This is the transcribed text from the audio file."
}
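The same transcription request can be made from Python. This sketch builds the `multipart/form-data` body by hand with the standard library, so it needs no third-party packages. The form field names (`file`, `model`, `response_format`, `language`) and the `text` response key are copied from the curl example and response above; `encode_multipart` and `transcribe` are hypothetical helpers.

```python
import json
import mimetypes
import os
import urllib.request
import uuid
from typing import Dict, Tuple


def encode_multipart(
    fields: Dict[str, str], file_field: str, filename: str, file_bytes: bytes
) -> Tuple[bytes, str]:
    """Return (body, content_type) for a multipart/form-data request with one file."""
    boundary = uuid.uuid4().hex
    parts = []
    for name, value in fields.items():
        parts.append(
            f'--{boundary}\r\nContent-Disposition: form-data; name="{name}"\r\n\r\n{value}\r\n'.encode("utf-8")
        )
    mime = mimetypes.guess_type(filename)[0] or "application/octet-stream"
    file_header = (
        f'--{boundary}\r\nContent-Disposition: form-data; name="{file_field}"; filename="{filename}"\r\n'
        f"Content-Type: {mime}\r\n\r\n"
    )
    parts.append(file_header.encode("utf-8") + file_bytes + b"\r\n")
    # The closing delimiter has two trailing dashes.
    parts.append(f"--{boundary}--\r\n".encode("utf-8"))
    return b"".join(parts), f"multipart/form-data; boundary={boundary}"


def transcribe(
    path: str,
    url: str = "http://localhost:8000/v1/audio/transcriptions",
    model: str = "whisper",
    language: str = "en",
) -> str:
    """Upload an audio file and return the "text" field of the JSON response."""
    with open(path, "rb") as f:
        audio = f.read()
    body, content_type = encode_multipart(
        {"model": model, "response_format": "json", "language": language},
        "file",
        os.path.basename(path),
        audio,
    )
    request = urllib.request.Request(url, data=body, headers={"Content-Type": content_type}, method="POST")
    with urllib.request.urlopen(request, timeout=600) as response:
        return json.loads(response.read().decode("utf-8"))["text"]
```

With `requests` installed, the equivalent call is shorter: `requests.post(url, files={"file": open(path, "rb")}, data={...})`.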

Advanced Configuration

Load Balancing and Scaling

For high-traffic applications, increase max_workers so the server can handle more requests concurrently:

ed.vInferenceApiServer(
    inference,
    max_workers=8           # More workers for parallel requests
).fire(
    host="0.0.0.0",
    port=8000
)
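Whether more workers actually raise throughput depends on how the server schedules work on the accelerator, so it is worth measuring. One way is to fire the same requests one at a time and then several at once, and compare wall-clock time. The sketch uses `concurrent.futures.ThreadPoolExecutor`; `send` is any function that performs one request (for example, one that posts a chat request), and `run_concurrently` is a hypothetical helper.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, List, Sequence, Tuple


def run_concurrently(
    send: Callable[[Any], Any], payloads: Sequence[Any], concurrency: int
) -> Tuple[List[Any], float]:
    """Send payloads with up to `concurrency` requests in flight.

    Returns (results, elapsed_seconds); results are in the same order as `payloads`.
    """
    start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=concurrency) as pool:
        results = list(pool.map(send, payloads))
    return results, time.perf_counter() - start
```

Running it once with `concurrency=1` and once with `concurrency` equal to `max_workers` shows whether the extra workers help for your workload.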

Setting Metrics Port for Monitoring

To enable monitoring with Prometheus metrics:

ed.vInferenceApiServer(inference).fire(
    port=8000,
    metrics_port=8001,      # Port for Prometheus metrics
    log_level="info"        # Logging level (debug, info, warning, error)
)
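Prometheus scrapes metrics as plain text in its exposition format: one sample per line (`name{labels} value`, optionally followed by a timestamp), with `#` lines for HELP and TYPE comments. The sketch below fetches and parses that text so you can inspect the metrics without a full Prometheus setup. It assumes the conventional `/metrics` path on `metrics_port`; the actual path and metric names exposed by vInferenceApiServer are not documented here. The parser handles simple samples only (no `}` inside label values), and both function names are hypothetical.

```python
import urllib.request
from typing import Dict


def parse_prometheus_text(text: str) -> Dict[str, float]:
    """Map each sample ("name" or "name{labels}") to its value; skip comments and blank lines."""
    samples: Dict[str, float] = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        if "}" in line:
            # Labelled sample: the value follows the closing brace.
            key, _, rest = line.rpartition("}")
            key += "}"
        else:
            key, _, rest = line.partition(" ")
        value = rest.split()[0]  # an optional timestamp may follow the value
        samples[key.strip()] = float(value)  # float() also accepts "+Inf" and "NaN"
    return samples


def fetch_metrics(url: str = "http://localhost:8001/metrics") -> Dict[str, float]:
    """Download and parse the metrics page (the /metrics path is an assumption)."""
    with urllib.request.urlopen(url, timeout=10) as response:
        return parse_prometheus_text(response.read().decode("utf-8"))
```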

Secure HTTPS Configuration

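To serve over HTTPS, pass the paths to your TLS private key and certificate. Binding to port 443 typically requires root privileges on Linux: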

ed.vInferenceApiServer(inference).fire(
    port=443,
    ssl_keyfile="/path/to/key.pem",
    ssl_certfile="/path/to/cert.pem"
)
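Clients must then use `https://` URLs. If the certificate is self-signed, as is common for internal deployments, point the client at it explicitly rather than disabling verification; hostname checking stays on, so the certificate's subject alternative names must include the host you connect to. A minimal standard-library sketch, with `make_https_opener` as a hypothetical helper:

```python
import ssl
import urllib.request
from typing import Optional


def make_https_opener(cafile: Optional[str] = None) -> urllib.request.OpenerDirector:
    """Build a urllib opener that verifies the server certificate.

    With cafile=None the system trust store is used; pass the path to a
    self-signed cert.pem to trust it explicitly.
    """
    # create_default_context enables certificate verification and hostname checks.
    context = ssl.create_default_context(cafile=cafile)
    return urllib.request.build_opener(urllib.request.HTTPSHandler(context=context))
```

For example, `make_https_opener("/path/to/cert.pem").open(request)` behaves like `urllib.request.urlopen(request)` but trusts the server's own certificate.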

Authentication and Security

For API key authentication, you need to add your own FastAPI HTTP middleware and register it on the app before the server starts. Here's a simplified example that accepts a key either as a Bearer token in the Authorization header or as an Authorization query parameter:

from typing import Optional

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

API_KEYS = {"sk-abcdef1234567890", "sk-qwerty0987654321"}  # example keys only
API_KEY_NAME = "Authorization"

# Placeholder: replace with the FastAPI app that vInferenceApiServer serves.
# How to reach that app depends on your EasyDeL version.
app = FastAPI()


def extract_api_key(request: Request) -> Optional[str]:
    """Return the key from 'Authorization: Bearer <key>' or the Authorization query parameter."""
    header = request.headers.get(API_KEY_NAME, "")
    if header.startswith("Bearer "):
        return header[len("Bearer "):]
    return request.query_params.get(API_KEY_NAME)


@app.middleware("http")
async def require_api_key(request: Request, call_next):
    # Reject unauthenticated requests before they reach any endpoint.
    if extract_api_key(request) not in API_KEYS:
        return JSONResponse(status_code=403, content={"detail": "Could not validate credentials"})
    return await call_next(request)
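Two small hardening steps for the key check above: compare keys with `hmac.compare_digest`, whose running time does not depend on where two values first differ (unlike `==` or `in`), and load keys from the environment instead of hard-coding them. The environment variable name `EASYDEL_API_KEYS` and both helpers are hypothetical.

```python
import hmac
import os
from typing import Iterable, List, Optional


def load_api_keys(env_var: str = "EASYDEL_API_KEYS") -> List[str]:
    """Read a comma-separated list of keys from an environment variable (hypothetical name)."""
    return [k.strip() for k in os.environ.get(env_var, "").split(",") if k.strip()]


def is_valid_key(candidate: Optional[str], keys: Iterable[str]) -> bool:
    """Check `candidate` against every configured key with constant-time comparisons."""
    if not candidate:
        return False
    matched = False
    for key in keys:
        # Compare as bytes so non-ASCII input cannot raise TypeError;
        # keep looping after a match so timing does not reveal which key matched.
        matched |= hmac.compare_digest(candidate.encode("utf-8"), key.encode("utf-8"))
    return matched
```

Inside the middleware, `is_valid_key(extract_api_key(request), API_KEYS)` would replace the plain membership test.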

Docker Deployment

For production deployment, you can package the server in a Docker image. Example Dockerfile for NVIDIA GPUs:

FROM nvidia/cuda:12.2.0-devel-ubuntu22.04

# Install dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    python3 python3-pip && \
    rm -rf /var/lib/apt/lists/*

# Install JAX with CUDA 12 support, then EasyDeL and its dependencies
RUN python3 -m pip install --no-cache-dir "jax[cuda12]" "easydel[all]"

# Copy the server script (for example, the setup code above saved as serve_model.py)
WORKDIR /app
COPY serve_model.py .

# Expose API port
EXPOSE 8000

# Run the server
CMD ["python3", "serve_model.py"]
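Inside a container it is convenient to read the bind address and port from environment variables, so one image can run on different ports. The helper below is a hypothetical addition to `serve_model.py`; the variable names `HOST` and `PORT` are assumptions, and the result is meant to be passed to `.fire(...)` as in the examples above. The default host stays `0.0.0.0`, because a server bound to `127.0.0.1` inside a container is unreachable through the published port.

```python
import os
from typing import Dict, Mapping, Optional, Union


def server_settings(env: Optional[Mapping[str, str]] = None) -> Dict[str, Union[str, int]]:
    """Return fire() keyword arguments built from HOST/PORT environment variables."""
    env = os.environ if env is None else env
    port_text = env.get("PORT", "8000")
    try:
        port = int(port_text)
    except ValueError:
        raise ValueError(f"PORT must be an integer, got {port_text!r}") from None
    if not 0 < port < 65536:
        raise ValueError(f"PORT out of range: {port}")
    return {"host": env.get("HOST", "0.0.0.0"), "port": port}


# In serve_model.py (sketch): ed.vInferenceApiServer(inference).fire(**server_settings())
```

The image can then be started with, for example, `docker run --gpus all -e PORT=8000 -p 8000:8000 <image>` (GPU access requires the NVIDIA Container Toolkit on the host).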

Client Integration Examples

Python Client

import base64

import requests

def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')

# Encode the image
image_path = "path/to/image.jpg"
base64_image = encode_image(image_path)

# Prepare the API request
url = "http://localhost:8000/v1/chat/completions"
headers = {
    "Content-Type": "application/json",
    "Authorization": "Bearer sk-your-api-key"  # If API key is enabled
}
payload = {
    "model": "mmprojector",
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": base64_image
                },
                {
                    "type": "text",
                    "text": "What's in this image?"
                }
            ]
        }
    ],
    "temperature": 0.7,
    "max_tokens": 300
}

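# Make the API request and print the generated text
response = requests.post(url, headers=headers, json=payload, timeout=300)
response.raise_for_status()
print(response.json()["choices"][0]["message"]["content"])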

JavaScript/TypeScript Client

async function queryMultimodalModel() {
    // Convert the image to base64
    const imageFile = document.getElementById('imageInput').files[0];
    const base64Image = await convertToBase64(imageFile);
  
    // Prepare the API request
    const response = await fetch('http://localhost:8000/v1/chat/completions', {
        method: 'POST',
        headers: {
            'Content-Type': 'application/json',
            'Authorization': 'Bearer sk-your-api-key'  // If API key is enabled
        },
        body: JSON.stringify({
            model: 'mmprojector',
            messages: [
                {
                    role: 'user',
                    content: [
                        {
                            type: 'image',
                            image: base64Image
                        },
                        {
                            type: 'text',
                            text: 'What is in this image?'
                        }
                    ]
                }
            ],
            temperature: 0.7,
            max_tokens: 300
        })
    });
  
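    if (!response.ok) {
        throw new Error(`Request failed: ${response.status} ${response.statusText}`);
    }

    const result = await response.json();
    document.getElementById('result').textContent = result.choices[0].message.content;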
}

// Helper function to convert file to base64
function convertToBase64(file) {
    return new Promise((resolve, reject) => {
        const reader = new FileReader();
        reader.readAsDataURL(file);
        reader.onload = () => {
            let encoded = reader.result.toString().replace(/^data:(.*,)?/, '');
            resolve(encoded);
        };
        reader.onerror = error => reject(error);
    });
}

Performance Tips

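  1. Precompile with Exact Dimensions: Always precompile with the exact input shapes the model will receive in production (for LLaVA-1.5, 336x336 images after preprocessing, set via vision_height and vision_width)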

  2. Quantization: Use appropriate quantization methods for your model size and hardware

  3. Batch Size: Adjust batch_size in precompile settings based on your expected traffic patterns

  4. Streaming Chunks: Fine-tune the streaming_chunks parameter in vInferenceConfig to balance memory usage and compilation overhead

  5. Worker Count: Set max_workers based on your CPU cores and expected concurrent requests
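Tips 1 and 3 both come down to keeping input shapes fixed, since a JIT-compiled JAX function is specialized to the shapes it was compiled for. A common general technique for serving variable-length prompts with a small set of compiled shapes is bucketing: pad each prompt up to the smallest supported length that fits. The sketch below shows that idea in plain Python with left padding, matching `processor.padding_side = "left"` in the setup example. The bucket sizes and pad id are illustrative, and this is not a description of what vInference does internally.

```python
import bisect
from typing import List, Sequence, Tuple


def pick_bucket(length: int, buckets: Sequence[int]) -> int:
    """Return the smallest bucket >= length; `buckets` must be sorted ascending."""
    if not buckets:
        raise ValueError("no buckets configured")
    i = bisect.bisect_left(buckets, length)
    if i == len(buckets):
        raise ValueError(f"prompt of {length} tokens exceeds largest bucket {buckets[-1]}")
    return buckets[i]


def left_pad(
    token_ids: Sequence[int], buckets: Sequence[int], pad_id: int = 0
) -> Tuple[List[int], List[int]]:
    """Left-pad token_ids to its bucket; return (input_ids, attention_mask)."""
    target = pick_bucket(len(token_ids), buckets)
    n_pad = target - len(token_ids)
    input_ids = [pad_id] * n_pad + list(token_ids)
    # Padding positions are masked out (0); real tokens are attended to (1).
    attention_mask = [0] * n_pad + [1] * len(token_ids)
    return input_ids, attention_mask


if __name__ == "__main__":
    buckets = [512, 1024, 2048]  # illustrative prefill lengths
    ids, mask = left_pad(list(range(700)), buckets)
    print(len(ids), sum(mask))
```

Fewer, larger buckets mean fewer compilations but more wasted compute on padding; more buckets trade the other way.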