easydel.modules.mamba.modeling_mamba_flax

class easydel.modules.mamba.modeling_mamba_flax.Lambda(*args: Any, **kwargs: Any)

Bases: Module

fn: Callable
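Lambda exposes a single field, fn. The plain-Python sketch below shows the wrapper pattern, assuming the module simply forwards its call arguments to fn. That calling convention is an assumption, and LambdaSketch is a hypothetical stand-in, not the Flax module itself.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class LambdaSketch:
    """Hypothetical stand-in for a module that wraps an arbitrary callable."""

    fn: Callable[..., Any]

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Forward every argument unchanged to the wrapped function.
        return self.fn(*args, **kwargs)


square = LambdaSketch(fn=lambda x: x * x)
print(square(3))  # 9
```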
class easydel.modules.mamba.modeling_mamba_flax.MambaBlock(*args: Any, **kwargs: Any)

Bases: Module
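The reference Mamba architecture wraps each mixer in a pre-norm residual connection. Assuming MambaBlock follows that layout (its internals are not documented here), the forward pass reduces to the sketch below. The names mixer and norm are hypothetical placeholders for the block's mixing layer and normalization layer, and rms_norm omits any learned scale.

```python
from typing import Callable, List

Vector = List[float]


def rms_norm(x: Vector, eps: float = 1e-6) -> Vector:
    # Scale x by the reciprocal of its root-mean-square (no learned weight here).
    rms = (sum(v * v for v in x) / len(x) + eps) ** 0.5
    return [v / rms for v in x]


def mamba_block_sketch(
    x: Vector,
    mixer: Callable[[Vector], Vector],
    norm: Callable[[Vector], Vector] = rms_norm,
) -> Vector:
    # Pre-norm residual: output = x + mixer(norm(x)).
    mixed = mixer(norm(x))
    return [a + b for a, b in zip(x, mixed)]
```

With a mixer that returns zeros, the block reduces to the identity, which makes the residual path easy to check.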

class easydel.modules.mamba.modeling_mamba_flax.MambaCausalLMOutput(last_hidden_state: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number] = None, hidden_states: Optional[Tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]]] = None, attentions: Optional[Tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]]] = None, past_key_values: Optional[Dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]]] = None, loss: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, NoneType] = None, logits: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number] = None, cache: Optional[easydel.layers.caching.mamba.mamba_cache.MambaCache] = None)

Bases: BaseModelOutput

cache: Optional[MambaCache] = None
classmethod from_dict(data: Dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

hidden_states: Optional[Tuple[Union[Array, ndarray, bool, number]]] = None
logits: Union[Array, ndarray, bool, number] = None
replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.
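The methods above describe an immutable container: replace returns a new instance rather than mutating, and to_dict/to_json pair with from_dict/from_json for round-tripping. The sketch below reproduces that contract with a frozen standard-library dataclass. It is an illustration only, not EasyDeL's PyTree implementation; CausalLMOutputSketch is hypothetical and models just three of the fields, with plain lists standing in for arrays.

```python
import dataclasses
import json
from typing import Any, Dict, List, Optional


@dataclasses.dataclass(frozen=True)
class CausalLMOutputSketch:
    """Hypothetical stand-in mirroring a few MambaCausalLMOutput fields."""

    logits: Optional[List[float]] = None
    loss: Optional[float] = None
    cache: Optional[Any] = None

    def replace(self, **kwargs: Any) -> "CausalLMOutputSketch":
        # Build a new instance; the original stays unchanged.
        return dataclasses.replace(self, **kwargs)

    def to_dict(self) -> Dict[str, Any]:
        return dataclasses.asdict(self)

    def to_json(self, **kwargs: Any) -> str:
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "CausalLMOutputSketch":
        return cls(**data)

    @classmethod
    def from_json(cls, json_str: str) -> "CausalLMOutputSketch":
        return cls.from_dict(json.loads(json_str))


out = CausalLMOutputSketch(logits=[0.1, 0.9])
out2 = out.replace(loss=0.5)
restored = CausalLMOutputSketch.from_json(out2.to_json())
```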

class easydel.modules.mamba.modeling_mamba_flax.MambaConv1D(*args: Any, **kwargs: Any)

Bases: Module
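In the reference Mamba design the short convolution is depthwise (one kernel per channel) and causal (left-padded by kernel_size - 1, so output position t only sees inputs up to t). Assuming MambaConv1D follows that design, which is not stated here, a plain-Python sketch for a single sequence looks like this. The function name and the [time][channel] layout are hypothetical choices for illustration.

```python
from typing import List

Sequence2D = List[List[float]]  # indexed as [time][channel]


def causal_depthwise_conv1d(
    x: Sequence2D, kernels: Sequence2D, bias: List[float]
) -> Sequence2D:
    """Hypothetical sketch: x is [time][channel], kernels is [channel][tap]."""
    seq_len, channels = len(x), len(x[0])
    k = len(kernels[0])
    out = []
    for t in range(seq_len):
        row = []
        for c in range(channels):
            acc = bias[c]
            for j in range(k):
                # Tap j looks back (k - 1 - j) steps; indices before 0 act as zero padding.
                src = t - (k - 1) + j
                if src >= 0:
                    acc += kernels[c][j] * x[src][c]
            row.append(acc)
        out.append(row)
    return out
```

Because each channel has its own kernel and only past inputs contribute, changing a later input never alters earlier outputs.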

class easydel.modules.mamba.modeling_mamba_flax.MambaForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

init_cache(batch_size: int, max_length: int, starts: int | None = None, shardings: dict | None = None, pad_token_id: int | None = None)

Initializes and returns a standard (non-paged) Key-Value cache.

This method first creates the necessary metadata using create_cache_metadata and then calls TransformerCache.init_cache to allocate and initialize the cache tensors based on the model’s configuration, dtype, sharding, quantization settings, and provided batch size and maximum length.

Parameters
  • batch_size (int) – The batch size for the cache.

  • max_length (int) – The maximum sequence length the cache needs to support.

  • starts (int | None) – Optional starting positions for the cache sequences. If provided, they influence the initial state. Defaults to None (typically treated as 0).

  • shardings (dict | None) – Optional dictionary specifying sharding configurations. Note: this argument may be unused by the current implementation.

  • pad_token_id (int | None) – The ID of the padding token. If None, it’s inferred.

Returns

An initialized standard TransformerCache object.

Return type

TransformerCache

prepare_inputs_for_generation(input_ids, max_length: int, pad_token_id: int, starts: int | None = None, **kwargs)

Sets up the initial inputs required for starting autoregressive generation.

This function initializes the Key-Value cache (past_key_values) using init_cache, calculates the initial position_ids based on the input attention_mask (or assumes a contiguous range if no mask is provided), and prepares an extended attention_mask suitable for caching. It ensures inputs are placed on the correct devices/shards.
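One common way to derive initial position IDs from a left-padded attention mask is a running count of attended tokens, minus one, clamped at zero; without a mask, each row gets a contiguous range. The sketch below assumes that convention, which the description above is consistent with but does not spell out. initial_position_ids is a hypothetical helper operating on nested lists instead of arrays.

```python
from typing import List, Optional


def initial_position_ids(
    seq_length: int,
    attention_mask: Optional[List[List[int]]] = None,
    batch_size: int = 1,
) -> List[List[int]]:
    if attention_mask is None:
        # No mask: every row gets the contiguous range 0..seq_length-1.
        return [list(range(seq_length)) for _ in range(batch_size)]
    position_ids = []
    for row in attention_mask:
        running, ids = 0, []
        for m in row:
            running += m
            # Padding positions before the first real token clamp to 0.
            ids.append(max(running - 1, 0))
        position_ids.append(ids)
    return position_ids


print(initial_position_ids(5, [[0, 0, 1, 1, 1]]))  # [[0, 0, 0, 1, 2]]
```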

Parameters
  • input_ids (chex.Array) – The initial sequence of token IDs. Shape (batch_size, seq_length).

  • max_length (int) – The maximum sequence length that the KV cache should support.

  • pad_token_id (int) – The ID used for padding tokens. Used to calculate starts if not provided.

  • starts (int | None) – Optional pre-calculated starting positions (number of leading pads). If None, calculated using compute_prefill_length.

  • shardings (dict | None) – Optional sharding configuration passed to init_cache.

  • attention_mask (tp.Optional[chex.Array]) – An optional mask indicating which tokens should be attended to. Shape (batch_size, seq_length).

  • token_type_ids (tp.Optional[chex.Array]) – Optional segment IDs for models that use them.
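The starts parameter is described as the number of leading pad tokens, which compute_prefill_length derives when it is omitted. Assuming left padding with pad_token_id, that per-row count can be sketched as follows. count_leading_pads is a hypothetical helper for illustration, not the library function.

```python
from typing import List


def count_leading_pads(input_ids: List[List[int]], pad_token_id: int) -> List[int]:
    """Hypothetical sketch: number of pad tokens before the first real token, per row."""
    starts = []
    for row in input_ids:
        n = 0
        for token in row:
            if token != pad_token_id:
                break  # stop at the first real token; later pads are not counted
            n += 1
        starts.append(n)
    return starts


print(count_leading_pads([[0, 0, 5, 6], [7, 8, 9, 10]], pad_token_id=0))  # [2, 0]
```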

Returns

A dictionary containing the prepared inputs, typically including:
  • "past_key_values": The initialized KV cache.

  • "attention_mask": The extended attention mask for generation.

  • "position_ids": The calculated initial position IDs.

  • "token_type_ids": (Optional) Prepared token type IDs.

This dictionary is then passed through prepare_inputs_for_call.

Return type

dict

update_inputs_for_generation(outputs: MambaOutput, model_kwargs: Dict[str, Any], **kwargs) → Dict[str, Any]

Updates the keyword arguments for the next generation step.

Specifically, it takes the past_key_values from the outputs of the current step and stores them in model_kwargs. It also increments the position_ids by one for the next token prediction.

Parameters
  • outputs (MambaOutput) – The output object from the model’s forward pass in the step just completed (should contain a past_key_values attribute).

  • model_kwargs (dict) – The dictionary of keyword arguments used for the model call. It may be modified in place, or a new dictionary may be returned.

Returns

The updated model_kwargs dictionary ready for the next generation step.

Return type

dict
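The per-step update described above amounts to carrying the new state forward and advancing each row's position by one. A minimal sketch, assuming model_kwargs stores position_ids as nested lists and that only the last position per row is kept for the next call (a common convention, not confirmed here). update_inputs_sketch and the SimpleNamespace stand-in for the outputs object are hypothetical.

```python
from types import SimpleNamespace
from typing import Any, Dict


def update_inputs_sketch(outputs: Any, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # Copy so the caller's dictionary is not mutated.
    updated = dict(model_kwargs)
    # Carry the state produced by this step into the next call.
    updated["past_key_values"] = outputs.past_key_values
    # Keep only the last position per row and advance it by one.
    updated["position_ids"] = [[row[-1] + 1] for row in model_kwargs["position_ids"]]
    return updated


step_out = SimpleNamespace(past_key_values={"layer_0": "state"})
kwargs = {"position_ids": [[0, 1, 2], [0, 0, 1]]}
next_kwargs = update_inputs_sketch(step_out, kwargs)
print(next_kwargs["position_ids"])  # [[3], [2]]
```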

class easydel.modules.mamba.modeling_mamba_flax.MambaMixer(*args: Any, **kwargs: Any)

Bases: Module
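At the heart of the reference Mamba mixer is a selective state-space recurrence with a diagonal state matrix A, discretised per step as h_t = exp(Δ_t A) h_{t-1} + Δ_t B_t x_t and read out as y_t = C_t · h_t + D x_t. Assuming MambaMixer follows that standard formulation (its internals are not documented here), the sketch below runs the recurrence for a single channel. It deliberately omits the input projection, convolution, gating, and output projection a full mixer would include; selective_scan_1ch is a hypothetical name.

```python
import math
from typing import List


def selective_scan_1ch(
    x: List[float],
    delta: List[float],
    a: List[float],
    b: List[List[float]],
    c: List[List[float]],
    d: float,
) -> List[float]:
    """Hypothetical single-channel sketch.

    x, delta: length T (input and per-step step size)
    a: length N (diagonal state matrix, typically negative)
    b, c: T x N (input-dependent projections)
    """
    n = len(a)
    h = [0.0] * n
    ys = []
    for t, x_t in enumerate(x):
        for i in range(n):
            # Decay the previous state, then inject the current input.
            h[i] = math.exp(delta[t] * a[i]) * h[i] + delta[t] * b[t][i] * x_t
        # Read out the state and add the skip connection D * x_t.
        ys.append(sum(c[t][i] * h[i] for i in range(n)) + d * x_t)
    return ys
```

With A = 0 and all other coefficients set to 1 (D = 0), the state never decays and the scan reduces to a running sum, which is a quick sanity check.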

class easydel.modules.mamba.modeling_mamba_flax.MambaModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

init_cache(batch_size: int, max_length: int, starts: int | None = None, shardings: dict | None = None, pad_token_id: int | None = None)

Initializes and returns a standard (non-paged) Key-Value cache.

This method first creates the necessary metadata using create_cache_metadata and then calls TransformerCache.init_cache to allocate and initialize the cache tensors based on the model’s configuration, dtype, sharding, quantization settings, and provided batch size and maximum length.

Parameters
  • batch_size (int) – The batch size for the cache.

  • max_length (int) – The maximum sequence length the cache needs to support.

  • starts (int | None) – Optional starting positions for the cache sequences. If provided, they influence the initial state. Defaults to None (typically treated as 0).

  • shardings (dict | None) – Optional dictionary specifying sharding configurations. Note: this argument may be unused by the current implementation.

  • pad_token_id (int | None) – The ID of the padding token. If None, it’s inferred.

Returns

An initialized standard TransformerCache object.

Return type

TransformerCache

class easydel.modules.mamba.modeling_mamba_flax.MambaOutput(last_hidden_state: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number] = None, hidden_states: Optional[Tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]]] = None, attentions: Optional[Tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]]] = None, past_key_values: Optional[Dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]]] = None, loss: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, NoneType] = None, cache: Optional[easydel.layers.caching.mamba.mamba_cache.MambaCache] = None)

Bases: BaseModelOutput

cache: Optional[MambaCache] = None
classmethod from_dict(data: Dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

hidden_states: Optional[Tuple[Union[Array, ndarray, bool, number]]] = None
last_hidden_state: Union[Array, ndarray, bool, number] = None
replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

easydel.modules.mamba.modeling_mamba_flax.create_tuple_parser(n: int) → Callable[[Union[_T, Sequence[_T]]], tuple[_T, ...]]
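Judging from its signature, create_tuple_parser(n) returns a function that accepts either a single value or a sequence and yields a tuple. A plausible sketch, assuming single values are repeated n times and sequences are converted with tuple(); treating strings as single values and skipping any length check are further assumptions. create_tuple_parser_sketch is a hypothetical name.

```python
import collections.abc
from itertools import repeat
from typing import Any, Callable, Sequence, Tuple, TypeVar, Union

_T = TypeVar("_T")


def create_tuple_parser_sketch(n: int) -> Callable[[Union[_T, Sequence[_T]]], Tuple[_T, ...]]:
    def parse(x: Any) -> Tuple[Any, ...]:
        # Strings are sequences too, so treat them as single values.
        if isinstance(x, collections.abc.Sequence) and not isinstance(x, str):
            return tuple(x)
        # A single value is broadcast to an n-tuple.
        return tuple(repeat(x, n))

    return parse


to_pair = create_tuple_parser_sketch(2)
print(to_pair(3), to_pair([4, 5]))  # (3, 3) (4, 5)
```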
easydel.modules.mamba.modeling_mamba_flax.init_to_value(x, dtype)