easydel.layers.caching._abstracts#

Abstract base classes for caching systems in EasyDeL.

This module provides the foundational abstract classes that define the interface for all caching implementations in EasyDeL. These abstractions enable different caching strategies (transformer KV-cache, paged attention, state-space models, etc.) to share a common interface while allowing for architecture-specific optimizations.

The caching system is built on a three-tier hierarchy: 1. Metadata classes that store configuration parameters 2. View classes that manage cache state for individual layers 3. Cache classes that orchestrate multiple views across all layers

Key Classes:

BaseCacheMetadata: Abstract base for cache configuration metadata BaseRunTimeMetadata: Abstract base for runtime metadata during computation BaseCacheView: Abstract base for single-layer cache management BaseCache: Abstract base for multi-layer cache orchestration

Design Principles:
  • Functional updates: All cache modifications return new instances

  • PyTree compatibility: All classes use auto_pytree for JAX integration

  • Type safety: Strong typing with generics and protocols

  • Extensibility: Easy to add new caching strategies

Example

To implement a new caching strategy, extend the base classes:

>>> class MyCustomMetadata(BaseCacheMetadata):
...     my_param: int
...
...     @classmethod
...     def create(cls, my_param: int) -> MyCustomMetadata:
...         if my_param <= 0:
...             raise ValueError("my_param must be positive")
...         return cls(my_param=my_param)
>>> class MyCustomView(BaseCacheView):
...     # Implementation details
...     pass
class easydel.layers.caching._abstracts.BaseCache[source]#

Bases: ABC

Abstract base class for multi-layer cache orchestration.

A cache container manages cache views across all layers of a model, providing a unified interface for cache initialization, access, and updates. It acts as the top-level cache object that users interact with.

The cache container pattern enables: - Centralized cache management across layers - Batch operations on all cache views - Consistent initialization and configuration - Easy serialization and checkpointing

Key responsibilities: - Maintain a collection of cache views (one per layer) - Provide factory methods for cache initialization - Enable indexed access to individual layer caches - Support batch operations across all layers - Handle cache serialization and restoration

Design principles: - Composition: Aggregates multiple cache views - Consistency: Ensures all views share compatible configuration - Flexibility: Supports different view types per layer if needed - Convenience: Provides list-like interface for view access

views#

Ordered collection of cache views, one per model layer. None values indicate uninitialized or disabled cache for that layer.

Type

tp.Sequence[BaseCacheView | None]

Note

The class provides default implementations for common operations like indexing and length, but concrete classes must implement the initialization methods.

abstract classmethod init_cache(metadata: BaseCacheMetadata, *args, **kwargs) BaseCache[source]#

Initialize a complete cache with views for all layers.

This factory method creates a fully initialized cache with allocated storage for all layers. It’s the primary way to create a cache for inference, setting up all necessary views with consistent configuration.

The initialization process: 1. Validates metadata configuration 2. Determines resource allocation strategy 3. Creates views for each layer 4. Applies sharding and quantization 5. Returns ready-to-use cache

Parameters
  • metadata (BaseCacheMetadata) – Configuration metadata defining cache dimensions, number of layers, and behavior.

  • *args – Additional positional arguments. Common args include: - mesh: JAX device mesh for distributed execution - dtype: Default dtype for cache tensors - num_layers: Override for number of layers

  • **kwargs – Additional keyword arguments. Common kwargs include: - partition_manager: Sharding configuration - quantizer: Quantization settings - device: Device placement preferences - initial_positions: Starting positions per layer

Returns

A fully initialized cache with views for all layers.

Ready for use in model inference.

Return type

BaseCache

Raises
  • ValueError – If metadata is incompatible with cache type.

  • MemoryError – If insufficient memory for allocation.

  • RuntimeError – If device/sharding configuration fails.

Example

>>> cache = TransformerCache.init_cache(
...     metadata=metadata,
...     mesh=mesh,
...     dtype=jnp.bfloat16,
...     partition_manager=pm
... )
>>> print(f"Initialized cache with {len(cache)} layers")
abstract classmethod init_empty(*args, **kwargs) BaseCache[source]#

Initialize an empty cache container without allocated storage.

Creates a cache structure with placeholder views that can be populated later. This is useful for: - Gradual cache building during training - Memory-efficient initialization - Dynamic cache allocation - Testing and debugging

The empty cache has the correct structure but no allocated tensors, allowing the shape and configuration to be determined dynamically.

Parameters
  • *args – Positional arguments. Common args include: - num_layers: Number of layers to create placeholders for

  • **kwargs – Keyword arguments for future compatibility.

Returns

A cache instance with uninitialized (None) views.

Views must be populated before use.

Return type

BaseCache

Example

>>> cache = TransformerCache.init_empty(num_hidden_layers=12)
>>> # Populate views gradually
>>> for i in range(12):
...     cache[i] = TransformerCacheView.init(...)
views: Sequence[easydel.layers.caching._abstracts.BaseCacheView | None]#
class easydel.layers.caching._abstracts.BaseCacheMetadata[source]#

Bases: ABC

Abstract base class defining the interface for cache metadata.

This class serves as the foundation for all cache metadata implementations. Metadata objects store static configuration that defines how a cache should be initialized and operated. They are immutable after creation and can be safely shared across multiple cache instances.

The metadata pattern separates configuration from state, allowing: - Validation of parameters at creation time - Reuse of configurations across multiple caches - Serialization of cache configurations - Type-safe parameter passing

Concrete implementations must: - Define all required configuration fields as class attributes - Implement the create() factory method with validation - Use the @auto_pytree decorator for JAX compatibility - Ensure all fields are immutable types or frozen dataclasses

All attributes are implementation-specific and should be documented
in the concrete class.

Note

The @auto_pytree decorator makes this class compatible with JAX’s PyTree protocol, enabling it to be used in JAX transformations.

abstract classmethod create(*args, **kwargs) BaseCacheMetadata[source]#

Factory method to create and validate a metadata instance.

This method serves as the primary constructor for metadata objects, providing a centralized location for parameter validation and initialization logic. It should be used instead of direct instantiation to ensure all metadata objects are properly validated.

The factory pattern allows for: - Complex initialization logic beyond simple assignment - Parameter validation before object creation - Derived parameter calculation - Consistent error handling across implementations

Parameters
  • *args – Positional arguments for metadata creation. Implementation-specific parameters.

  • **kwargs – Keyword arguments for metadata creation. Implementation-specific parameters.

Returns

A validated instance of the concrete metadata

implementation. The returned object is immutable and ready for use in cache initialization.

Return type

BaseCacheMetadata

Raises
  • ValueError – If any validation checks fail. Common validations include: - Positive integer checks for dimensions and sizes - Range checks for ratios and percentages - Consistency checks between related parameters - Resource availability checks

  • TypeError – If required parameters are missing or have incorrect types.

Example

>>> metadata = TransformerCacheMetaData.create(
...     batch_size=4,
...     sequence_length=1024,
...     num_hidden_layers=12,
...     num_heads=8,
...     head_dim=64
... )
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.layers.caching._abstracts.BaseCacheView[source]#

Bases: ABC

Abstract base class for single-layer cache management.

A cache view represents the cache state for a single layer in a neural network. It encapsulates the layer-specific cached data (e.g., key/value pairs for attention, conv/SSM states for Mamba) and provides methods to update this data during inference.

The view pattern allows: - Layer-specific optimization and sharding - Independent cache management per layer - Flexible cache formats for different layer types - Efficient memory layout for each layer’s needs

Key responsibilities: - Store cached states for one model layer - Track the current position/index in the cache - Update cache with new computed states - Apply quantization if configured - Manage memory layout and sharding

Design principles: - Functional updates: Methods return new instances, not modify in-place - Layer isolation: Each view is independent of others - Type flexibility: Support both dense and quantized representations - Sharding aware: Integrate with JAX’s sharding system

metadata#

Configuration metadata for this cache. Shared across all views in the same cache hierarchy.

Type

BaseCacheMetadata

layer_index#

The index of the layer this view represents. None for cache types that don’t have layer structure.

Type

int | None

Note

While marked as ABC, this class doesn’t use @auto_pytree because concrete implementations need to control their PyTree structure.

abstract concatenate_to_cache(*args, **kwargs) Any[source]#

Update the cache with new computed states.

This is the primary method for cache updates during inference. It takes newly computed states (keys, values, hidden states, etc.) and incorporates them into the cache, returning updated tensors and any additional information needed for computation.

The update process typically: 1. Validates input shapes and dtypes 2. Determines update position in cache 3. Applies quantization if configured 4. Updates cache tensors functionally 5. Adjusts masks and indices 6. Returns updated state for next computation

Parameters
  • *args – Positional arguments vary by cache type but commonly include: - key: New key states (for attention caches) - value: New value states (for attention caches) - hidden_states: New hidden states (for SSM caches) - positions: Sequence positions for update

  • **kwargs – Keyword arguments vary by cache type but commonly include: - attention_mask: Mask for valid positions - cache_metadata: Runtime metadata for update - quantizer: Quantization function - causal_mask: Causal attention pattern - mode: Prefill vs generation mode

Returns

Return type varies by implementation but typically includes:
  • Updated cache tensors (functional return)

  • Modified attention masks

  • Updated view instance

  • Additional computation results

Common return patterns: - Transformer: (key_cache, value_cache, mask, updated_view) - Mamba: (updated_view,) - Paged: (updated_view,)

Return type

tp.Any

Note

This method should be functional, returning new tensors rather than modifying existing ones in-place. This ensures compatibility with JAX’s functional programming model.

Example

>>> key_cache, value_cache, mask, new_view = view.concatenate_to_cache(
...     query=query_states,
...     key=key_states,
...     value=value_states,
...     attention_mask=mask
... )
abstract classmethod init(metadata: BaseCacheMetadata, *args, **kwargs) BaseCacheView[source]#

Initialize a new cache view for a single layer.

This factory method creates and initializes a cache view with the appropriate tensor shapes, dtypes, and sharding for a specific layer. It allocates the actual cache storage and sets up initial state.

The initialization process typically: 1. Calculates tensor shapes from metadata 2. Determines sharding strategy for distributed execution 3. Allocates cache tensors with appropriate dtype 4. Applies quantization if configured 5. Sets initial indices and positions

Parameters
  • metadata (BaseCacheMetadata) – Static configuration metadata that defines cache dimensions, dtypes, and behavior.

  • *args – Additional positional arguments. Common args include: - mesh: JAX device mesh for sharding - dtype: JAX dtype for cache tensors - layer_index: Index of the layer - partition_manager: Sharding configuration

  • **kwargs – Additional keyword arguments. Common kwargs include: - quantizer: Quantization configuration - initial_position: Starting cache position - device: Specific device placement

Returns

An initialized cache view ready for use.

The view contains allocated tensors and is configured for the specific layer’s requirements.

Return type

BaseCacheView

Raises
  • ValueError – If metadata parameters are invalid for this view type.

  • MemoryError – If cache allocation fails due to insufficient memory.

Example

>>> view = TransformerCacheView.init(
...     metadata=metadata,
...     mesh=mesh,
...     dtype=jnp.bfloat16,
...     layer_index=0
... )
layer_index: int | None#
metadata: BaseCacheMetadata#
class easydel.layers.caching._abstracts.BaseRunTimeMetadata[source]#

Bases: object

Abstract base class for runtime metadata used during cache operations.

Runtime metadata captures dynamic information that varies during model execution but isn’t part of the permanent cache state. This includes temporary computation state, indices, masks, and other ephemeral data needed during cache updates.

Unlike cache metadata (which is static configuration), runtime metadata: - Changes frequently during model execution - Is specific to individual forward passes - May not need to be preserved between calls - Often contains computation intermediates

Common uses for runtime metadata: - Current sequence positions and offsets - Batch-specific indices and mappings - Temporary buffers for computation - Attention masks and patterns - Memory management information

Concrete implementations should: - Keep fields lightweight and computation-focused - Avoid storing large tensors unless necessary - Use appropriate JAX dtypes for indices (int32) - Document which fields are required vs optional

Note

This class intentionally has no abstract methods, as runtime metadata requirements vary significantly between cache types. Implementations should add fields and methods as needed.

classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.