easydel.modules.xerxes2.modeling_xerxes2_flax
- class easydel.modules.xerxes2.modeling_xerxes2_flax.Xerxes2Attention(*args: Any, **kwargs: Any)
Bases: AttentionModule
- class easydel.modules.xerxes2.modeling_xerxes2_flax.Xerxes2DecoderLayer(*args: Any, **kwargs: Any)
Bases: Module
- class easydel.modules.xerxes2.modeling_xerxes2_flax.Xerxes2ForCausalLM(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
- create_cache_metadata(batch_size: int, max_length: int, pad_token_id: int | None = None)
Creates the metadata required for initializing a standard (non-paged) KV Cache.
This method gathers parameters such as the layer count and head dimensions, determines the appropriate padding token ID, and returns a TransformerCacheMetaData object suitable for a standard sequential KV cache.
- Parameters
batch_size (int) – The batch size for which the cache is being configured.
max_length (int) – The maximum sequence length the cache needs to support.
pad_token_id (int | None) – The ID of the padding token. If None, the method looks it up in self.generation_config or self.config and falls back to 0.
- Returns
An initialized metadata object for a standard KV cache.
- Return type
TransformerCacheMetaData
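The padding-token fallback described for pad_token_id can be illustrated with a minimal sketch. The helper name resolve_pad_token_id is hypothetical and is not part of EasyDeL; the lookup order (explicit argument, then generation_config, then config, then 0) is an assumption based on the parameter description above, not a confirmed detail of the implementation.

```python
from types import SimpleNamespace


def resolve_pad_token_id(pad_token_id, generation_config=None, config=None):
    """Hypothetical stand-in for the documented pad-token fallback.

    Assumed order: explicit argument, generation_config, config, then 0.
    """
    if pad_token_id is not None:
        return pad_token_id
    for source in (generation_config, config):
        # getattr with a default also handles a missing (None) source object.
        candidate = getattr(source, "pad_token_id", None)
        if candidate is not None:
            return candidate
    return 0


# Illustrative stand-ins for self.generation_config and self.config.
generation_config = SimpleNamespace(pad_token_id=None)
config = SimpleNamespace(pad_token_id=3)

explicit = resolve_pad_token_id(7, generation_config, config)
from_config = resolve_pad_token_id(None, generation_config, config)
default = resolve_pad_token_id(None)
```

Under these assumptions, an explicit argument always wins, and 0 is used only when neither configuration object supplies a value.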
- init_cache(batch_size: int, max_length: int, starts: int | None = None, shardings: dict | None = None, pad_token_id: int | None = None)
Initializes and returns a standard (non-paged) Key-Value cache.
This method first creates the necessary metadata using create_cache_metadata and then calls TransformerCache.init_cache to allocate and initialize the cache tensors based on the model’s configuration, dtype, sharding, quantization settings, and provided batch size and maximum length.
- Parameters
batch_size (int) – The batch size for the cache.
max_length (int) – The maximum sequence length the cache needs to support.
starts (int | None) – Optional starting positions for the cache sequences. If provided, it influences the initial cache state. Defaults to None, in which case sequences usually start at position 0.
shardings (dict | None) – Optional dictionary specifying sharding configurations. Note: this argument appears to be unused in the current implementation.
pad_token_id (int | None) – The ID of the padding token. If None, it is inferred in the same way as in create_cache_metadata.
- Returns
An initialized standard TransformerCache object.
- Return type
TransformerCache
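The two-step pattern described above (build the metadata first, then allocate the cache from it) might look roughly like the following sketch. Every name here (CacheMetaSketch, create_metadata_sketch, init_cache_sketch) is hypothetical, and the per-layer buffer shape (batch, max_length, num_kv_heads, head_dim) is an assumed layout chosen for illustration; it does not claim to match TransformerCacheMetaData or TransformerCache.init_cache.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CacheMetaSketch:
    """Hypothetical metadata record; field names are assumptions."""
    batch_size: int
    max_length: int
    num_layers: int
    num_kv_heads: int
    head_dim: int
    pad_token_id: int


def create_metadata_sketch(batch_size, max_length, *, num_layers,
                           num_kv_heads, head_dim, pad_token_id=0):
    # Step 1: validate the sizes and collect everything the allocator needs.
    if batch_size <= 0 or max_length <= 0:
        raise ValueError("batch_size and max_length must be positive")
    return CacheMetaSketch(batch_size, max_length, num_layers,
                           num_kv_heads, head_dim, pad_token_id)


def init_cache_sketch(meta):
    # Step 2: derive one key/value buffer description per layer.
    # A real cache would allocate arrays here; this sketch records shapes only.
    shape = (meta.batch_size, meta.max_length, meta.num_kv_heads, meta.head_dim)
    return [{"key_shape": shape, "value_shape": shape}
            for _ in range(meta.num_layers)]


meta = create_metadata_sketch(2, 16, num_layers=3, num_kv_heads=4, head_dim=8)
cache = init_cache_sketch(meta)
```

Separating the metadata from the allocation keeps the size bookkeeping in one place, so the allocation step depends only on the metadata object.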
- class easydel.modules.xerxes2.modeling_xerxes2_flax.Xerxes2MLP(*args: Any, **kwargs: Any)
Bases: Module
- class easydel.modules.xerxes2.modeling_xerxes2_flax.Xerxes2Model(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule