Vision-Language Models
EasyDeL supports vision-language models (VLMs): multimodal models that take both images and text as input. This page shows how to load and run several of them with EasyDeL, from contrastive image-text matchers (CLIP, SigLIP) to generative assistants (LLaVA, Gemma3, Aya Vision, Qwen2VL).
Supported Models
EasyDeL supports several vision-language model families, including:
LLaVA: Large Language and Vision Assistant
Gemma3: Google’s multimodal model with vision capabilities
Aya Vision: A powerful vision-language model from Cohere
CLIP: OpenAI’s Contrastive Language-Image Pre-training model
SigLIP: Google's Sigmoid Loss for Language Image Pre-training model
Qwen2VL: the vision-language model from Alibaba's Qwen team
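Basic Usage Pattern
Most vision-language models in EasyDeL follow a similar pattern:
Load the processor, which handles both tokenization and image preprocessing
Load the model with the appropriate configuration (sharding, dtypes, attention mechanism)
Build the inputs: for chat models, apply a chat template whose messages include the images; for contrastive models such as CLIP and SigLIP, call the processor on the texts and images directly
Call the model directly (contrastive models) or generate text through vInference (chat models)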
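The chat-style examples on this page all build the same message structure: a user turn whose content lists an image entry followed by a text entry. As a minimal sketch, a small helper can produce it. `build_image_chat` is a hypothetical name, not part of EasyDeL or transformers; its output is meant to be passed to `processor.apply_chat_template`.

```python
from typing import Any, Dict, List, Optional


def build_image_chat(
    image: str, prompt: str, system: Optional[str] = None
) -> List[Dict[str, Any]]:
    """Hypothetical helper: build a one-turn chat with one image and one prompt.

    `image` is a URL or local path, in the same form the examples below use.
    """
    messages: List[Dict[str, Any]] = []
    if system is not None:
        # Plain-string system turn, as in the Qwen2VL example
        messages.append({"role": "system", "content": system})
    messages.append(
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    )
    return messages


messages = build_image_chat(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
    "Describe this image in detail.",
)
```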
LLaVA Model Example
LLaVA (Large Language and Vision Assistant) is a popular open-source vision-language model that connects a vision encoder with a language model.
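The vision encoder's patch embeddings reach the language model as ordinary input positions, so each image consumes part of the prompt budget. The following worked calculation assumes LLaVA-1.5's CLIP ViT-L/14 encoder at 336x336, which matches the `vision_height`/`vision_width` used below, with one token per 14x14 patch. `num_patch_tokens` is a hypothetical helper.

```python
def num_patch_tokens(image_size: int, patch_size: int) -> int:
    """Number of patch tokens a square image yields in a plain ViT encoder."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be a multiple of patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side


image_tokens = num_patch_tokens(336, 14)  # 24 x 24 patches
prefill_length = 2048                     # same value as the example below
text_budget = prefill_length - image_tokens

print(image_tokens, text_budget)  # 576 1472
```

With one image, roughly 1,472 positions remain for the chat template and the question.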
import easydel as ed
import jax
from jax import numpy as jnp
from transformers import AutoProcessor
# Setup parameters
prefill_length = 2048
max_new_tokens = 1024
max_length = max_new_tokens + prefill_length
model_name = "llava-hf/llava-1.5-7b-hf"
# Load model and processor
processor = AutoProcessor.from_pretrained(model_name)
processor.padding_side = "left"
model = ed.AutoEasyDeLModelForImageTextToText.from_pretrained(
    model_name,
    auto_shard_model=True,
    sharding_axis_dims=(1, 1, -1, 1),
    config_kwargs=ed.EasyDeLBaseConfigDict(
        freq_max_position_embeddings=max_length,
        mask_max_position_embeddings=max_length,
        attn_mechanism=ed.AttentionMechanisms.VANILLA,
    ),
    dtype=jnp.bfloat16,
    param_dtype=jnp.bfloat16,
)
# Prepare input with image and text
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
            },
            {"type": "text", "text": "Describe this image in detail."},
        ],
    },
]
# Process inputs
inputs = processor.apply_chat_template(
    messages,
    return_tensors="jax",
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
)
# Initialize inference
inference = ed.vInference(
    model=model,
    processor_class=processor,
    generation_config=ed.vInferenceConfig(
        max_new_tokens=max_new_tokens,
        sampling_params=ed.SamplingParams(
            max_tokens=max_new_tokens,
            temperature=0.8,
            top_p=0.95,
            top_k=10,
        ),
        eos_token_id=model.generation_config.eos_token_id,
        streaming_chunks=32,
        num_return_sequences=1,
    ),
)
# Precompile for specific dimensions to optimize performance
inference.precompile(
    ed.vInferencePreCompileConfig(
        batch_size=1,
        prefill_length=prefill_length,
        vision_included=True,  # Important for vision models
        vision_batch_size=1,  # Number of images
        vision_channels=3,  # RGB channels
        vision_height=336,  # Image height
        vision_width=336,  # Image width
    )
)
# Generate response
for response in inference.generate(**inputs):
    pass  # Process streaming tokens here if needed
# `response` now holds the last streamed state; decode only the generated tokens
result = processor.batch_decode(
    response.sequences[..., response.padded_length:],
    skip_special_tokens=True,
)[0]
print(result)
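The final slice works because the processor pads prompts on the left. Every prompt in the batch therefore ends at the same position, and one offset separates the prompt from the generated tokens in every row. The minimal sketch below uses made-up token ids. It assumes, as the slice above implies, that `sequences` holds the padded prompt followed by the generated tokens. `strip_prompt` is a hypothetical helper.

```python
from typing import List

PAD = 0  # hypothetical pad token id


def strip_prompt(sequences: List[List[int]], padded_length: int) -> List[List[int]]:
    """Drop the left-padded prompt, keeping only newly generated tokens."""
    return [row[padded_length:] for row in sequences]


# Two prompts of different lengths, both left-padded to 6 positions,
# each followed by 3 generated tokens (all ids are illustrative).
sequences = [
    [PAD, PAD, 11, 12, 13, 14, 101, 102, 103],
    [PAD, PAD, PAD, PAD, 21, 22, 201, 202, 203],
]
generated = strip_prompt(sequences, padded_length=6)
print(generated)  # [[101, 102, 103], [201, 202, 203]]
```

With right padding, the shorter prompt would be followed by pad tokens before generation starts. This is why decoder-only generation generally uses left padding.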
CLIP Image-Text Matching
CLIP (Contrastive Language-Image Pre-training) can be used for zero-shot image classification, image-text similarity, and more:
import easydel as ed
import jax
from transformers import CLIPProcessor
from PIL import Image
import requests
# Load model and processor
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = ed.AutoEasyDeLModelForZeroShotImageClassification.from_pretrained(
    "openai/clip-vit-base-patch32"
)
# Load an image
url = "http://images.cocodataset.org/val2017/000000039769.jpg" # cat image
image = Image.open(requests.get(url, stream=True).raw)
# Process inputs
inputs = processor(
    text=[
        "a photo of a cat",
        "a photo of a dog",
        "a photo of a person",
        "a photo of a car",
    ],
    images=image,
    return_tensors="np",
    padding=True,
)
# Get predictions
with model.mesh:
    outputs = model(**inputs)
logits_per_image = outputs.logits_per_image
probs = jax.nn.softmax(logits_per_image, axis=1)
print(f"Prediction probabilities: {probs[0]}")
# Should show highest probability for "a photo of a cat"
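CLIP's `logits_per_image` are cosine similarities scaled by a learned temperature. The softmax turns them into a distribution over the candidate captions, so the probabilities for one image always sum to 1. Here is a pure-Python sketch of that step, using illustrative logits rather than real model output:

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over one image's caption logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


labels = ["a photo of a cat", "a photo of a dog", "a photo of a person", "a photo of a car"]
logits = [25.0, 18.0, 16.0, 12.0]  # illustrative values, not real CLIP output

probs = softmax(logits)
best = labels[probs.index(max(probs))]
print(best)  # a photo of a cat
```

Because the probabilities compete, adding or removing a caption changes the scores of the others. The candidate list is part of the question you ask.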
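SigLIP Image-Text Matching
SigLIP (Sigmoid Loss for Language Image Pre-training) is Google's image-text model. Unlike CLIP, it is trained with a pairwise sigmoid loss, so each image-text pair is scored independently: apply a sigmoid, not a softmax, to its logits: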
import easydel as ed
import jax
from jax import numpy as jnp
from PIL import Image
import requests
from transformers import AutoProcessor
# Load image
image = Image.open(
    requests.get(
        "http://images.cocodataset.org/val2017/000000039769.jpg",
        stream=True,
    ).raw
)
# Load processor and model
processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
model = ed.AutoEasyDeLModel.from_pretrained(
    "google/siglip-base-patch16-224",
    auto_shard_model=True,
    sharding_axis_dims=(1, 1, -1, 1),
)
# Prepare inputs
texts = ["a photo of a cat", "a photo of a dog", "a photo of a bird"]
inputs = processor(
    text=texts,
    images=image,
    padding="max_length",  # SigLIP was trained with max_length padding
    return_tensors="jax",
)
# Get predictions
with model.mesh:
    outputs = model(**inputs)
# Process results
probs = jax.nn.sigmoid(outputs.logits_per_image)
for i, text in enumerate(texts):
    print(f"{float(probs[0, i]):.1%} probability that the image is '{text}'")
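Applying a sigmoid instead of a softmax means each caption is scored independently. Adding a caption leaves the other probabilities unchanged, and the values need not sum to 1. The pure-Python comparison below uses illustrative logits, not real SigLIP output:

```python
import math
from typing import List


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def softmax(logits: List[float]) -> List[float]:
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [4.0, -6.0, -9.0]  # cat, dog, bird (illustrative)
extended = logits + [3.0]   # add a fourth caption that also matches well

sig = [sigmoid(z) for z in logits]
sig_extended = [sigmoid(z) for z in extended]
soft = softmax(logits)
soft_extended = softmax(extended)

print(sig == sig_extended[:3])     # True: sigmoid scores are independent
print(soft[0] > soft_extended[0])  # True: softmax redistributes probability
```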
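Aya Vision
Aya Vision is an open-weights multilingual vision-language model from Cohere For AI. The example below loads it with NF4 quantization to reduce memory use: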
import easydel as ed
import jax
from jax import numpy as jnp
from transformers import AutoProcessor
# Load processor and model
processor = AutoProcessor.from_pretrained("CohereForAI/aya-vision-8b")
processor.padding_side = "left"
model = ed.AutoEasyDeLModelForImageTextToText.from_pretrained(
    "CohereForAI/aya-vision-8b",
    auto_shard_model=True,
    sharding_axis_dims=(1, 1, -1, 1),
    config_kwargs=ed.EasyDeLBaseConfigDict(
        attn_mechanism=ed.AttentionMechanisms.VANILLA,
    ),
    quantization_method=ed.EasyDeLQuantizationMethods.NF4,  # Quantization for efficiency
    param_dtype=jnp.float16,
    dtype=jnp.float16,
)
# Prepare input with image
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
            },
            {"type": "text", "text": "Describe this image in detail."},
        ],
    },
]
# Process inputs and generate
inputs = processor.apply_chat_template(
    messages,
    return_tensors="jax",
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
)
inference = ed.vInference(
    model=model,
    processor_class=processor,
    generation_config=ed.vInferenceConfig(
        max_new_tokens=1024,
        streaming_chunks=32,
    ),
)
# Precompile for Aya Vision's 364x364 input resolution
inference.precompile(
    ed.vInferencePreCompileConfig(
        vision_included=True,
        vision_height=364,
        vision_width=364,
    )
)
# Generate response
result = inference.generate_text(**inputs)
print(result)
Gemma3 Multimodal
Google's Gemma3 accepts image and text inputs in its 4B and larger variants (the 1B model is text-only). Images are processed at 896x896:
import easydel as ed
import jax
from jax import numpy as jnp
from transformers import AutoProcessor
# Load processor and model
processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
processor.padding_side = "left"
model = ed.AutoEasyDeLModelForImageTextToText.from_pretrained(
    "google/gemma-3-4b-it",
    auto_shard_model=True,
    sharding_axis_dims=(1, 1, -1, 1),
    config_kwargs=ed.EasyDeLBaseConfigDict(
        attn_mechanism=ed.AttentionMechanisms.VANILLA,
    ),
    param_dtype=jnp.bfloat16,
    dtype=jnp.float16,
)
# Prepare input with image
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
            },
            {"type": "text", "text": "Describe this image in detail."},
        ],
    },
]
# Process inputs
inputs = processor.apply_chat_template(
    messages,
    return_tensors="jax",
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
)
# Initialize inference
inference = ed.vInference(
    model=model,
    processor_class=processor,
    generation_config=ed.vInferenceConfig(
        max_new_tokens=1024,
        sampling_params=ed.SamplingParams(
            temperature=0.8,
            top_p=0.95,
            top_k=10,
        ),
    ),
)
# Precompile for specific dimensions
inference.precompile(
    ed.vInferencePreCompileConfig(
        vision_included=True,
        vision_batch_size=1,
        vision_height=896,
        vision_width=896,
    )
)
# Generate and get result
result = inference.generate_text(**inputs)
print(result)
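Qwen2VL Model
Qwen2VL is the vision-language model from Alibaba's Qwen team. It handles images at dynamic resolutions. Before reaching the processor, its images are extracted and resized by the process_vision_info helper from the separate qwen-vl-utils package: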
import easydel as ed
import jax
from jax import numpy as jnp
from transformers import AutoProcessor
# Configuration
min_pixels = 256 * 28 * 28
resized_height, resized_width = 420, 420
# Load processor and model
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    min_pixels=min_pixels,
    max_pixels=min_pixels,
    resized_height=resized_height,
    resized_width=resized_width,
)
processor.padding_side = "left"
model = ed.AutoEasyDeLModelForImageTextToText.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    auto_shard_model=True,
    config_kwargs=ed.EasyDeLBaseConfigDict(
        attn_mechanism=ed.AttentionMechanisms.VANILLA,
    ),
    param_dtype=jnp.bfloat16,
    dtype=jnp.float16,
)
# Prepare conversation with image
messages = [
    {"role": "system", "content": "You are a helpful AI assistant."},
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "https://picsum.photos/seed/picsum/200/300",
                "min_pixels": min_pixels,
                "max_pixels": min_pixels,
                "resized_height": resized_height,
                "resized_width": resized_width,
            },
            {"type": "text", "text": "Describe what you see in this image."},
        ],
    },
]
# Qwen2VL needs its images extracted and resized by qwen_vl_utils
# (a separate package: pip install qwen-vl-utils)
from qwen_vl_utils import process_vision_info
image_inputs, video_inputs = process_vision_info(messages)
# Process inputs
inputs = processor(
    text=[processor.apply_chat_template(messages, add_generation_prompt=True)],
    images=image_inputs,
    videos=video_inputs,
    max_length=2048,
    padding="max_length",
    return_tensors="jax",
)
# Initialize inference
inference = ed.vInference(
    model=model,
    processor_class=processor,
    generation_config=ed.vInferenceConfig(
        max_new_tokens=128,
        sampling_params=ed.SamplingParams(
            temperature=0.8,
            top_p=0.95,
        ),
    ),
)
# Generate response
result = inference.generate_text(**inputs)
print(result)
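The pixel settings above translate directly into sequence length. Qwen2VL uses 14x14 patches and merges each 2x2 group into one visual token, so every token covers a 28x28 pixel block. That is why the settings are written in multiples of `28 * 28`. The worked sketch below shows the arithmetic. `qwen2vl_visual_tokens` is a hypothetical helper, and it ignores the vision start/end marker tokens.

```python
PIXELS_PER_TOKEN_SIDE = 28  # 14-pixel patches, merged 2x2 into one token


def qwen2vl_visual_tokens(height: int, width: int) -> int:
    """Visual tokens for an image already resized to multiples of 28 pixels."""
    if height % PIXELS_PER_TOKEN_SIDE or width % PIXELS_PER_TOKEN_SIDE:
        raise ValueError("height and width must be multiples of 28")
    return (height // PIXELS_PER_TOKEN_SIDE) * (width // PIXELS_PER_TOKEN_SIDE)


min_pixels = 256 * 28 * 28
print(min_pixels // (28 * 28))          # 256 tokens at min_pixels
print(qwen2vl_visual_tokens(420, 420))  # 225 tokens for a 420x420 image
print(420 * 420 < min_pixels)           # True
```

Note that 420 x 420 = 176,400 pixels, which is below this `min_pixels` value (200,704). The processor's own resizing may therefore enlarge the image. The `image_grid_thw` entry returned by the processor shows the grid that was actually used.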
Advanced Features
Quantization Options
To reduce the memory footprint, you can load the weights quantized. The quality loss is usually small, but check the outputs on your own task:
model = ed.AutoEasyDeLModelForImageTextToText.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    quantization_method=ed.EasyDeLQuantizationMethods.NF4,  # Use NF4 quantization
    # Other parameters...
)
Supported quantization methods:
NF4: 4-bit quantization for efficient inference
A8BIT: 8-bit quantization
NONE: no quantization (default)
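You can get a rough sense of the savings by counting bytes per weight. Below is a back-of-the-envelope sketch for a model with about 7 billion parameters, such as LLaVA-1.5-7B (the count is rounded). It covers the weights only. It ignores quantization scales, layers kept in higher precision, activations, and the KV cache.

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Approximate weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


num_params = 7e9  # rounded parameter count
for name, bits in [("bfloat16", 16), ("A8BIT", 8), ("NF4", 4)]:
    print(f"{name}: {weight_memory_gb(num_params, bits):.1f} GB")
# bfloat16: 14.0 GB
# A8BIT: 7.0 GB
# NF4: 3.5 GB
```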
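Memory Optimization with Attention Mechanisms
For large vision-language models, the attention implementation has a large effect on activation memory for long prompts. Memory-efficient kernels avoid materializing the full attention score matrix: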
model = ed.AutoEasyDeLModelForImageTextToText.from_pretrained(
    "llava-hf/llava-1.5-13b-hf",  # Larger model
    config_kwargs=ed.EasyDeLBaseConfigDict(
        attn_mechanism=ed.AttentionMechanisms.FLASH_ATTENTION,  # More efficient attention
        # Other parameters...
    ),
    # Other parameters...
)
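During prefill, a vanilla implementation materializes a score matrix of shape (batch, heads, seq_len, seq_len), which grows quadratically with the prompt length. Flash-style kernels compute attention in tiles without storing the full matrix. The worked estimate below makes three assumptions: 40 attention heads (as in the 13B Vicuna/LLaMA language model behind LLaVA-1.5-13B), a 2,048-token prefill, and 2-byte (bfloat16) scores. Real implementations may keep scores in float32, which doubles the figure. `score_matrix_bytes` is a hypothetical helper.

```python
def score_matrix_bytes(batch: int, heads: int, seq_len: int, bytes_per_elem: int = 2) -> int:
    """Bytes for one layer's full attention-score tensor (batch, heads, seq, seq)."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


per_layer = score_matrix_bytes(batch=1, heads=40, seq_len=2048)
print(per_layer)  # 335544320 bytes, about 0.34 GB per layer
print(score_matrix_bytes(1, 40, 4096) // per_layer)  # 4: doubling the prompt quadruples it
```

Each image adds its tokens to `seq_len`, so image-heavy prompts feel this growth first.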
Custom Image Dimensions
The precompiled vision shapes must match the pixel_values your processor produces. When working with a different image size, precompile with those dimensions:
inference.precompile(
    ed.vInferencePreCompileConfig(
        vision_included=True,
        vision_batch_size=1,
        vision_channels=3,
        vision_height=512,  # Custom height
        vision_width=768,  # Custom width
    )
)
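A mismatch between these values and the actual `pixel_values` shape is easy to miss. The minimal guard below assumes the processor returns `pixel_values` in (batch, channels, height, width) layout, as the LLaVA and Gemma3 processors do. Qwen2VL's processor returns flattened patches instead, so this check does not apply there. `check_pixel_values_shape` is a hypothetical helper.

```python
from typing import Sequence, Tuple


def check_pixel_values_shape(
    shape: Sequence[int], batch_size: int, channels: int, height: int, width: int
) -> None:
    """Raise if an NCHW pixel_values shape differs from the precompiled one."""
    expected: Tuple[int, ...] = (batch_size, channels, height, width)
    actual = tuple(shape)
    if actual != expected:
        raise ValueError(
            f"pixel_values shape {actual} does not match precompiled shape {expected}"
        )


# In practice, pass inputs["pixel_values"].shape; plain tuples stand in here.
check_pixel_values_shape((1, 3, 512, 768), batch_size=1, channels=3, height=512, width=768)
try:
    check_pixel_values_shape((1, 3, 336, 336), batch_size=1, channels=3, height=512, width=768)
except ValueError as err:
    print(err)
```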
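Performance Tips
Use Quantization: For large VLMs, use NF4 or A8BIT quantization.
Optimize Attention: Choose FLASH_ATTENTION for GPU or SPLASH_ATTENTION for TPU.
Precompile: Call inference.precompile() with your batch size, prefill length, and image dimensions before generating, so compilation happens up front rather than on the first request.
Image Size: For models with dynamic resolution such as Qwen2VL, use the smallest image dimensions that give good results; LLaVA-1.5 and Gemma3 resize images to a fixed resolution by default.
Batch Processing: Process multiple images together when possible.
System Memory: VLMs use more memory than text-only models of the same size, because image tokens lengthen every prompt; adjust your batch size accordingly.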
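Precompilation targets a specific batch size, so it helps to keep every batch at that size. The minimal sketch below pads the last batch by repeating its final item. It also reports how many entries are real, so the filler results can be discarded. `fixed_size_batches` is a hypothetical helper. Repeating an item is only one illustrative choice, and it spends compute on the filler.

```python
from typing import Iterator, List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def fixed_size_batches(items: Sequence[T], size: int) -> Iterator[Tuple[List[T], int]]:
    """Yield (batch, num_real) pairs; every batch has exactly `size` entries."""
    if size < 1:
        raise ValueError("size must be at least 1")
    for start in range(0, len(items), size):
        batch = list(items[start : start + size])
        num_real = len(batch)
        batch.extend([batch[-1]] * (size - num_real))  # filler for the last batch
        yield batch, num_real


images = ["img0.jpg", "img1.jpg", "img2.jpg", "img3.jpg", "img4.jpg"]
for batch, num_real in fixed_size_batches(images, size=2):
    print(batch, num_real)
# ['img0.jpg', 'img1.jpg'] 2
# ['img2.jpg', 'img3.jpg'] 2
# ['img4.jpg', 'img4.jpg'] 1
```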