easydel.modules.xerxes2.modeling_xerxes2_flax

class easydel.modules.xerxes2.modeling_xerxes2_flax.Xerxes2Attention(*args: Any, **kwargs: Any)

Bases: AttentionModule

class easydel.modules.xerxes2.modeling_xerxes2_flax.Xerxes2DecoderLayer(*args: Any, **kwargs: Any)

Bases: Module

class easydel.modules.xerxes2.modeling_xerxes2_flax.Xerxes2ForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

create_cache_metadata(batch_size: int, max_length: int, pad_token_id: int | None = None)

Creates the metadata required for initializing a standard (non-paged) KV Cache.

This method gathers parameters such as the layer count and head dimensions, determines the appropriate padding token ID, and returns a TransformerCacheMetaData object suitable for a standard sequential KV cache.

Parameters
  • batch_size (int) – The batch size for which the cache is being configured.

  • max_length (int) – The maximum sequence length the cache needs to support.

  • pad_token_id (int | None) – The ID of the padding token. If None, the method looks it up in self.generation_config or self.config and falls back to 0 if neither defines it.

Returns

An initialized metadata object for a standard KV cache.

Return type

TransformerCacheMetaData
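
The padding-token fallback described above can be sketched in plain Python. This is a minimal illustration, not the EasyDeL implementation: the helper name resolve_pad_token_id is hypothetical, and checking generation_config before config is an assumption based on the order in which the parameter description lists them.

```python
from types import SimpleNamespace


def resolve_pad_token_id(pad_token_id, generation_config, config):
    """Hypothetical helper mirroring the documented fallback order."""
    # An explicit argument always wins.
    if pad_token_id is not None:
        return pad_token_id
    # Assumed lookup order: generation_config first, then config.
    for source in (generation_config, config):
        candidate = getattr(source, "pad_token_id", None)
        if candidate is not None:
            return candidate
    # Documented default when nothing defines a padding token.
    return 0


# Stand-in objects; a real model would supply its own config objects.
generation_config = SimpleNamespace(pad_token_id=None)
config = SimpleNamespace(pad_token_id=2)

explicit = resolve_pad_token_id(7, generation_config, config)
from_config = resolve_pad_token_id(None, generation_config, config)
default = resolve_pad_token_id(None, None, SimpleNamespace())
```

Here explicit takes the caller's value, from_config falls through to config because generation_config holds None, and default ends at 0 because neither object defines a padding token.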

init_cache(batch_size: int, max_length: int, starts: int | None = None, shardings: dict | None = None, pad_token_id: int | None = None)

Initializes and returns a standard (non-paged) Key-Value cache.

This method first creates the necessary metadata using create_cache_metadata and then calls TransformerCache.init_cache to allocate and initialize the cache tensors based on the model’s configuration, dtype, sharding, quantization settings, and provided batch size and maximum length.

Parameters
  • batch_size (int) – The batch size for the cache.

  • max_length (int) – The maximum sequence length the cache needs to support.

  • starts (int | None) – Optional starting position for the cache sequences. If provided, it influences the initial cache state. Defaults to None, which is usually treated as position 0.

  • shardings (dict | None) – Optional dictionary specifying sharding configurations. Note: this argument appears to be unused in the current implementation.

  • pad_token_id (int | None) – The ID of the padding token. If None, it is inferred by create_cache_metadata.

Returns

An initialized standard TransformerCache object.

Return type

TransformerCache
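
To give a feel for what a standard (non-paged) cache allocates, the sketch below estimates the buffer shape and memory footprint from a small metadata record. Everything in it is an assumption made for illustration: ToyCacheMeta and both helper functions are hypothetical, the (batch, sequence, kv_heads, head_dim) layout is one common convention rather than the confirmed TransformerCache layout, and the 2-byte element size assumes an unquantized 16-bit dtype.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyCacheMeta:
    """Hypothetical stand-in for cache metadata; not the EasyDeL class."""
    batch_size: int
    max_length: int
    num_layers: int
    num_kv_heads: int
    head_dim: int
    pad_token_id: int = 0


def cache_buffer_shape(meta):
    # Assumed layout of one key (or value) buffer for a single layer.
    return (meta.batch_size, meta.max_length, meta.num_kv_heads, meta.head_dim)


def cache_size_bytes(meta, bytes_per_element=2):
    # One key buffer and one value buffer per layer, preallocated to max_length.
    elements = 1
    for dim in cache_buffer_shape(meta):
        elements *= dim
    return 2 * meta.num_layers * elements * bytes_per_element


meta = ToyCacheMeta(batch_size=2, max_length=128, num_layers=4,
                    num_kv_heads=8, head_dim=64)
shape = cache_buffer_shape(meta)
total_bytes = cache_size_bytes(meta)
```

Because a non-paged cache is sized for max_length up front, its footprint grows linearly with both batch_size and max_length, which is worth checking before choosing those arguments.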

class easydel.modules.xerxes2.modeling_xerxes2_flax.Xerxes2MLP(*args: Any, **kwargs: Any)

Bases: Module

class easydel.modules.xerxes2.modeling_xerxes2_flax.Xerxes2Model(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

property frequencies: Array

Returns frequency values from the config.