Caching#

EasyData provides a multi-layer caching system inspired by Levanter’s TreeCache for efficient data processing.

Overview#

The caching system provides:

  • Memory Cache: Fast LRU cache for recently accessed data

  • Disk Cache: Persistent cache with compression and expiry

  • TreeCacheManager: Hierarchical combination of both layers

  • Dataset Cache: Specialized cache for HuggingFace datasets

TreeCacheManager#

The main caching interface combining memory and disk layers:

from easydel.data import TreeCacheManager

# Create cache manager
cache = TreeCacheManager(
    cache_dir="./cache",
    memory_size=100,      # Max items in memory
    disk_expiry=86400,    # 24 hours
    compression="zstd",   # none, gzip, lz4, zstd
)

# Store data
cache.put("my_key", {"input_ids": [1, 2, 3]})

# Retrieve data
result = cache.get("my_key")
if result:
    data, metadata = result
    print(data)

# Check existence
if cache.contains("my_key"):
    print("Cached!")

# Invalidate
cache.invalidate("my_key")  # Single key
cache.invalidate()          # Clear all

Cache-on-Compute Pattern#

Use get_or_compute for lazy caching:

from easydel.data import TreeCacheManager, CacheMetadata

cache = TreeCacheManager("./cache")

def expensive_tokenization(text):
    """Tokenize ``text`` and return only the ``input_ids`` entry.

    Stands in for any costly computation whose result is worth caching.
    Relies on a ``tokenizer`` callable defined elsewhere.
    """
    encoded = tokenizer(text)
    return encoded["input_ids"]

# Will compute and cache if not exists
result = cache.get_or_compute(
    key="text_123",
    compute_fn=lambda: expensive_tokenization("Hello world"),
    metadata=CacheMetadata(
        source_hash="abc123",
        tokenizer_hash="llama2",
    ),
)

Cache Keys#

Generate consistent cache keys from configuration:

from easydel.data import TreeCacheManager

# From config dictionary
key = TreeCacheManager.compute_key(
    config={
        "tokenizer": "meta-llama/Llama-2-7b",
        "max_length": 2048,
        "data_file": "data.jsonl",
    },
    prefix="tokenized",
)
# "tokenized_a1b2c3d4e5f6"

# With content hash
key = TreeCacheManager.compute_key(
    config={"tokenizer": "llama"},
    prefix="example",
    include_content_hash=True,
    content="Hello world",
)
# "example_a1b2c3d4_e5f6g7h8"

CacheMetadata#

Track cache validity:

from easydel.data import CacheMetadata

metadata = CacheMetadata(
    version="1.0",
    source_hash="abc123",         # Hash of source data
    tokenizer_hash="llama2_7b",   # Hash of tokenizer
    transform_hash="v2",          # Hash of transforms
    num_examples=100000,
    config_hash="xyz789",
    extra={"custom": "data"},
)

# Check validity
if metadata.is_valid_for(config_hash="xyz789", source_hash="abc123"):
    print("Cache is valid")

Memory Cache (LRU)#

Fast in-memory cache with automatic eviction:

from easydel.data import MemoryCache

cache = MemoryCache(max_size=1000)

# Basic operations
cache.put("key1", {"data": "value"})
result = cache.get("key1")  # Returns (data, metadata) or None

# Statistics
stats = cache.stats
print(f"Hit rate: {stats['hit_rate']:.2%}")
print(f"Size: {stats['size']}/{stats['max_size']}")

Disk Cache#

Persistent cache with compression:

from easydel.data import DiskCache

cache = DiskCache(
    cache_dir="./disk_cache",
    compression="zstd",      # Best compression ratio
    expiry_seconds=86400,    # Auto-expire after 24 hours
)

# Operations
cache.put("key", large_data)
result = cache.get("key")  # Returns (data, metadata) or None

# Manual expiry check
if cache.contains("key"):  # Also checks expiry
    result = cache.get("key")

Compression Options#

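| Compression | Speed   | Ratio | Use Case               |
|-------------|---------|-------|------------------------|
| none        | Fastest | 1.0x  | SSDs, small data       |
| gzip        | Slow    | ~3x   | Maximum compatibility  |
| lz4         | Fast    | ~2x   | Balanced (recommended) |
| zstd        | Medium  | ~3.5x | Best ratio             |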

Dataset Cache#

Specialized cache for HuggingFace datasets:

from easydel.data import DatasetCache
from datasets import Dataset

cache = DatasetCache("./dataset_cache")

# Cache a dataset
dataset = Dataset.from_dict({"text": ["hello", "world"]})
cache.put("my_dataset", dataset)

# Load from cache
cached_dataset = cache.get("my_dataset")

# Check and invalidate
if cache.contains("my_dataset"):
    cache.invalidate("my_dataset")

Pipeline Integration#

CacheStageConfig#

Configure caching in pipelines:

from easydel.data import PipelineConfig, CacheStageConfig

config = PipelineConfig(
    datasets=[...],
    cache=CacheStageConfig(
        enabled=True,
        cache_type="hierarchical",  # memory, disk, hierarchical
        cache_dir=".cache/easydel_pipeline",
        memory_cache_size=100,
        disk_cache_expiry=86400,
        compression="lz4",
        hash_fn="combined",  # content, path, combined
    ),
)

Per-Dataset Caching#

from easydel.data import DatasetConfig

config = DatasetConfig(
    data_files="data/*.parquet",
    cache_path="./cache/my_dataset",  # Dataset-specific cache
    cache_enabled=True,
)

Caching Strategies#

1. Tokenization Cache#

Cache tokenized data to avoid re-tokenization:

from easydel.data import TreeCacheManager, CacheMetadata
from transformers import AutoTokenizer

cache = TreeCacheManager("./tokenizer_cache")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b")

def tokenize_with_cache(text: str, text_id: str):
    """Return the tokenization of ``text``, reusing a cached copy when present.

    Args:
        text: Raw text to tokenize.
        text_id: Stable identifier for ``text``; used to build the cache key.

    Returns:
        A plain ``dict`` of tokenizer outputs. The same dict type is returned
        on both a cache hit and a cache miss.
    """
    key = f"tokenized_{text_id}"

    result = cache.get(key)
    if result is not None:
        # Cache entries are (data, metadata) pairs; only the data is needed.
        return result[0]

    # Convert to a plain dict once so the value returned on a miss is the
    # same object type that a later cache hit will hand back.
    tokenized = dict(tokenizer(text, max_length=2048, truncation=True))
    cache.put(
        key,
        tokenized,
        metadata=CacheMetadata(
            tokenizer_hash="llama2_7b",
            num_examples=1,
        ),
    )
    return tokenized

2. Processed Dataset Cache#

Cache entire processed datasets:

from easydel.data import DatasetCache

cache = DatasetCache("./processed_datasets")

def get_processed_dataset(name: str, process_fn):
    """Return the dataset cached under ``name``, building it on a miss.

    Args:
        name: Cache key identifying the processed dataset.
        process_fn: Zero-argument callable that produces the dataset; only
            called when nothing is cached under ``name``.

    Returns:
        The cached dataset if one exists, otherwise the freshly processed one
        (which is stored in the cache before being returned).
    """
    # Check cache. Compare against None rather than relying on truthiness:
    # a cached dataset with zero rows has length 0 and would otherwise be
    # treated as a miss and reprocessed on every call.
    # NOTE(review): assumes cache.get returns None on a miss -- confirm.
    cached = cache.get(name)
    if cached is not None:
        print(f"Loaded {name} from cache")
        return cached

    # Process and cache
    print(f"Processing {name}...")
    dataset = process_fn()
    cache.put(name, dataset)
    return dataset

3. Checkpoint-Aware Caching#

Include training state in cache key:

from easydel.data import TreeCacheManager

cache = TreeCacheManager("./training_cache")

def get_cached_batch(step: int, shard: str, row: int):
    """Return the batch for a given training position, caching it on first use.

    Args:
        step: Training step the batch belongs to; part of the cache key so
            batches from different steps never collide.
        shard: Identifier of the data shard the batch is read from.
        row: Row offset within the shard.

    Returns:
        The cached batch if present, otherwise the result of
        ``load_and_process_batch(shard, row)`` (stored before returning).
    """
    # Include the training step in the key; without it the key is not
    # checkpoint-aware and every step would share one entry per (shard, row).
    key = f"batch_{step}_{shard}_{row}"

    result = cache.get(key)
    if result is not None:
        # Cache entries are (data, metadata) pairs; only the data is needed.
        return result[0]

    # Compute batch
    batch = load_and_process_batch(shard, row)
    cache.put(key, batch)
    return batch

Cache Invalidation#

Automatic Invalidation#

from easydel.data import TreeCacheManager, CacheMetadata

cache = TreeCacheManager("./cache")

def get_with_validation(key: str, expected_config_hash: str):
    """Fetch ``key`` via ``get_or_compute`` with a config-hash validator.

    Args:
        key: Cache key to look up.
        expected_config_hash: Hash stored in the new entry's metadata and
            compared against the ``config_hash`` of the metadata handed to
            the validator.

    Returns:
        Whatever ``cache.get_or_compute`` returns.
    """

    def _config_matches(meta):
        # Accept only metadata recorded under the expected configuration.
        return meta.config_hash == expected_config_hash

    return cache.get_or_compute(
        key=key,
        compute_fn=lambda: compute_data(),
        metadata=CacheMetadata(config_hash=expected_config_hash),
        validate_fn=_config_matches,
    )

Manual Invalidation#

# Single key
cache.invalidate("outdated_key")

# Pattern-based (iterate and invalidate)
for key in ["key1", "key2", "key3"]:
    if cache.contains(key):
        cache.invalidate(key)

# Clear all
cache.invalidate()

Expiry-Based#

from easydel.data import DiskCache

# Auto-expire after 1 hour
cache = DiskCache(
    cache_dir="./cache",
    expiry_seconds=3600,
)

# Old entries automatically invalidated on access

Best Practices#

1. Use Appropriate Cache Location#

# Fast SSD for frequently accessed data
cache = TreeCacheManager("/nvme/cache")

# Network storage for shared cache
cache = TreeCacheManager("/shared/cache")

# Memory-only for ephemeral data
from easydel.data import MemoryCache
cache = MemoryCache(max_size=1000)

2. Include Version in Cache Keys#

CACHE_VERSION = "v2"

key = TreeCacheManager.compute_key(
    config={"tokenizer": "llama", "version": CACHE_VERSION},
)

3. Monitor Cache Statistics#

cache = TreeCacheManager("./cache")

# Periodically log stats
stats = cache.stats
print(f"Memory: {stats['memory']}")
print(f"Disk: {stats['disk']}")

# Check hit rates
memory_hits = stats['memory']['hit_rate']
if memory_hits < 0.5:
    print("Consider increasing memory cache size")

4. Clean Up Old Caches#

import shutil
from pathlib import Path

def cleanup_old_caches(cache_dir: str, max_age_days: int = 7):
    """Delete sub-directories of ``cache_dir`` older than ``max_age_days``.

    Age is measured from each sub-directory's own modification time. Plain
    files directly inside ``cache_dir`` are left alone.

    Args:
        cache_dir: Root directory whose immediate sub-directories are caches.
        max_age_days: Sub-directories last modified more than this many days
            ago are removed.

    Returns:
        None. A line is printed for every directory removed.
    """
    import time

    cache_path = Path(cache_dir)
    # A cache root that was never created has nothing to clean; iterdir()
    # would otherwise raise FileNotFoundError.
    if not cache_path.is_dir():
        return

    now = time.time()
    max_age_seconds = max_age_days * 86400

    for item in cache_path.iterdir():
        # is_dir() follows symlinks, but shutil.rmtree refuses to operate on
        # a symlink, so skip links instead of crashing mid-cleanup.
        if item.is_symlink() or not item.is_dir():
            continue
        age = now - item.stat().st_mtime
        if age > max_age_seconds:
            shutil.rmtree(item)
            print(f"Removed old cache: {item}")

Troubleshooting#

Cache Not Being Used#

# Check cache contains expected key
key = TreeCacheManager.compute_key(config)
print(f"Looking for key: {key}")
print(f"Cache contains: {cache.contains(key)}")

# Verify metadata matches
result = cache.get(key)
if result:
    data, meta = result
    print(f"Cached metadata: {meta}")

Disk Space Issues#

# No size-limit option is used here; a short expiry keeps disk usage down
# because old entries are invalidated on access.
cache = DiskCache(
    cache_dir="./cache",
    expiry_seconds=3600,  # Short expiry (1 hour)
)

# Or use memory-only
cache = MemoryCache(max_size=100)

Stale Cache#

# Always include version/hash in metadata
metadata = CacheMetadata(
    config_hash=compute_config_hash(),
    tokenizer_hash=compute_tokenizer_hash(),
)

# Use validation function
cache.get_or_compute(
    key=key,
    compute_fn=compute,
    validate_fn=lambda m: m.config_hash == current_hash,
)

Next Steps#