# Source code for easydel.layers.caching._abstracts
# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import typing as tp
from abc import ABC, abstractmethod
from eformer.pytree import auto_pytree
if tp.TYPE_CHECKING:
    # Evaluated by static type checkers only; this branch is skipped at
    # runtime because ``typing.TYPE_CHECKING`` is ``False`` there.
    from jax.sharding import Mesh, PartitionSpec

    from easydel.layers.quantization.quantizers import EasyQuantizer
else:
    # Runtime placeholders so the names still resolve when the type-only
    # imports above are not executed.
    EasyQuantizer = object
    PartitionSpec = tp.Any
    Mesh = tp.Any
class BaseCacheView(ABC):
    """
    Abstract base class for a single cache view (typically per layer).

    Responsible for:
        - Storing cached key/value states
        - Tracking current cache position
        - Updating cache with new states

    Concrete subclasses must implement both :meth:`init` and
    :meth:`concatenate_to_cache`; the class cannot be instantiated until
    they do.
    """

    # Configuration metadata the view was built from.
    metadata: BaseCacheMetadata
    # Presumably the index of the layer this view serves; may be ``None``.
    layer_index: tp.Optional[int]

    @classmethod
    @abstractmethod
    def init(cls, metadata: BaseCacheMetadata, *args, **kwargs) -> BaseCacheView:
        """
        Initialize a new cache view instance.

        Args:
            metadata: Configuration metadata for the cache.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            Initialized cache view instance.
        """

    @abstractmethod
    def concatenate_to_cache(self, *args, **kwargs) -> tp.Any:
        """
        Update cache with new states.

        Args:
            *args: Typically includes new tensors.
            **kwargs: Additional parameters for cache update.

        Returns:
            Implementation-defined result of the update; the base class
            places no constraint on its type.
        """
class BaseCache(ABC):
    """
    Abstract base class for the main cache container.

    Manages a sequence of cache views (typically one per layer), exposes it
    through indexing, item assignment and ``len()``, and declares the
    constructors that concrete caches must implement.
    """

    # Ordered cache views; an entry may be ``None``. Declared as a mutable
    # sequence because ``__setitem__`` assigns into it by index.
    views: tp.MutableSequence[tp.Optional[BaseCacheView]]

    @classmethod
    @abstractmethod
    def init_cache(
        cls,
        metadata: BaseCacheMetadata,
        *args,
        **kwargs,
    ) -> BaseCache:
        """
        Initialize a complete cache with views for all layers.

        Args:
            metadata: Configuration metadata.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            Fully initialized cache instance.
        """

    @classmethod
    @abstractmethod
    def init_empty(cls, *args, **kwargs) -> BaseCache:
        """
        Initialize an empty cache container.

        Args:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            Cache instance with uninitialized views.
        """

    def __getitem__(self, index):
        """
        Enable indexing to access cache views.

        Args:
            index: Index of the cache view to retrieve; forwarded unchanged
                to ``self.views``.

        Returns:
            The cache view at the specified index.
        """
        return self.views[index]

    def __setitem__(self, index, value):
        """
        Enable item assignment to update cache views.

        Args:
            index: Index of the cache view to update.
            value: New cache view to assign.
        """
        self.views[index] = value

    def __len__(self) -> int:
        """
        Return the number of cache views.

        Returns:
            The number of items in the `views` sequence.

        Raises:
            AttributeError: If `self.views` has not been initialized by a
                subclass.
        """
        # ``views`` is only annotated on this class, never assigned, so an
        # instance whose subclass did not set it has no such attribute.
        if not hasattr(self, "views"):
            raise AttributeError(
                "The 'views' attribute has not been initialized. "
                "Ensure a concrete subclass initializes it."
            )
        return len(self.views)