easydel.modules.mamba2.modeling_mamba2_flax#

class easydel.modules.mamba2.modeling_mamba2_flax.Conv1D(*args: Any, **kwargs: Any)[source]#

Bases: Module

class easydel.modules.mamba2.modeling_mamba2_flax.Mamba2Block(*args: Any, **kwargs: Any)[source]#

Bases: Module

class easydel.modules.mamba2.modeling_mamba2_flax.Mamba2CausalLMOutput(last_hidden_state: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number] = None, hidden_states: Optional[Tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]]] = None, attentions: Optional[Tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]]] = None, past_key_values: Optional[Dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]]] = None, loss: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, NoneType] = None, logits: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number] = None, cache_params: Optional[easydel.layers.caching.mamba2_cache.Mamba2Cache] = None)[source]#

Bases: FlaxBaseModelOutput

cache_params: Optional[Mamba2Cache] = None#
hidden_states: Optional[Tuple[Union[Array, ndarray, bool, number]]] = None#
logits: Union[Array, ndarray, bool, number] = None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.
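As a minimal sketch of what replace does, the following uses a hypothetical stand-in dataclass (ToyOutput is not part of EasyDeL) together with the standard-library dataclasses.replace. It assumes the real output classes behave like immutable dataclasses, which this page does not spell out.

```python
import dataclasses
from typing import Optional, Tuple


# Hypothetical stand-in for an output container; not the EasyDeL class.
@dataclasses.dataclass(frozen=True)
class ToyOutput:
    logits: Optional[Tuple[float, ...]] = None
    hidden_states: Optional[Tuple[Tuple[float, ...], ...]] = None

    def replace(self, **updates):
        # Build a new instance; fields not named in `updates` are copied over.
        return dataclasses.replace(self, **updates)


out = ToyOutput(logits=(0.1, 0.9))
new_out = out.replace(hidden_states=((1.0, 2.0),))

# The original object is left unchanged; only the copy carries the update.
print(out.hidden_states)      # None
print(new_out.hidden_states)  # ((1.0, 2.0),)
print(new_out.logits)         # (0.1, 0.9)
```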

class easydel.modules.mamba2.modeling_mamba2_flax.Mamba2ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

init_cache(batch_size: int, max_length: int)[source]#
prepare_inputs_for_generation(input_ids, inputs_embeds=None, cache_params: Optional[Mamba2Cache] = None, cache_position: Optional[Union[Array, ndarray, bool, number]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None, **kwargs)[source]#

Prepares the model inputs for a generation step.

Parameters
  • input_ids – The input token IDs.

  • max_length – The length of the sequence to be generated. This name is not in the explicit signature above, so it can only arrive through **kwargs.

  • attention_mask – Optional[chex.Array]: Mask applied to the attention weights.

Returns

A dictionary containing past_key_values, attention_mask, and position IDs.

update_inputs_for_generation(model_outputs, model_kwargs)[source]#

class easydel.modules.mamba2.modeling_mamba2_flax.Mamba2Mixer(*args: Any, **kwargs: Any)[source]#

Bases: Module

class easydel.modules.mamba2.modeling_mamba2_flax.Mamba2Model(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.modules.mamba2.modeling_mamba2_flax.Mamba2Output(last_hidden_state: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number] = None, hidden_states: Optional[Tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]]] = None, attentions: Optional[Tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]]] = None, past_key_values: Optional[Dict[str, Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]]] = None, loss: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, NoneType] = None, cache_params: Optional[easydel.layers.caching.mamba2_cache.Mamba2Cache] = None)[source]#

Bases: FlaxBaseModelOutput

cache_params: Optional[Mamba2Cache] = None#
hidden_states: Optional[Tuple[Union[Array, ndarray, bool, number]]] = None#
last_hidden_state: Union[Array, ndarray, bool, number] = None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class easydel.modules.mamba2.modeling_mamba2_flax.MambaRMSNormGated(*args: Any, **kwargs: Any)[source]#

Bases: Module

easydel.modules.mamba2.modeling_mamba2_flax.create_tuple_parser(n: int) Callable[[Union[_T, Sequence[_T]]], tuple[_T, ...]][source]#
easydel.modules.mamba2.modeling_mamba2_flax.init_to_value(x, dtype)[source]#
easydel.modules.mamba2.modeling_mamba2_flax.pad_tensor_by_size(input_tensor: Array, pad_size: int)[source]#

Pads input_tensor by pad_size along the sequence-length dimension (dim=1).
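A minimal sketch of padding along the sequence dimension with jax.numpy. The helper name pad_seq_dim is hypothetical, and the choices to pad with zeros and to pad at the end of the sequence are assumptions; this page does not state which side or fill value the library uses.

```python
import jax.numpy as jnp


def pad_seq_dim(x, pad_size):
    """Zero-pad `x` by `pad_size` at the end of axis 1 (assumed behavior)."""
    pad_width = [(0, 0)] * x.ndim   # no padding on any axis by default
    pad_width[1] = (0, pad_size)    # pad only the tail of the seq_len axis
    return jnp.pad(x, pad_width, mode="constant", constant_values=0)


x = jnp.ones((2, 5, 3))             # (batch, seq_len, features)
padded = pad_seq_dim(x, 3)
print(padded.shape)                 # (2, 8, 3)
```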

easydel.modules.mamba2.modeling_mamba2_flax.reshape_into_chunks(input_tensor, pad_size, chunk_size)[source]#

Pads input_tensor by pad_size along the sequence-length dimension (dim=1) and splits the padded sequence into chunks of length chunk_size.
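The pad-then-chunk step can be sketched as follows. The helper name pad_and_chunk, the zero padding at the end of the sequence, the output layout (batch, num_chunks, chunk_size, ...), and the way pad_size is derived from chunk_size are all assumptions made for illustration, not the library's confirmed behavior.

```python
import jax.numpy as jnp


def pad_and_chunk(x, pad_size, chunk_size):
    """Pad axis 1 with zeros, then split it into chunks of `chunk_size`."""
    pad_width = [(0, 0)] * x.ndim
    pad_width[1] = (0, pad_size)
    x = jnp.pad(x, pad_width)       # default mode is constant zeros
    batch, seq_len = x.shape[0], x.shape[1]
    # seq_len must now be a multiple of chunk_size for the reshape to work.
    return x.reshape(batch, seq_len // chunk_size, chunk_size, *x.shape[2:])


x = jnp.arange(2 * 5 * 3, dtype=jnp.float32).reshape(2, 5, 3)
chunk_size = 4
# Assumed caller-side rule: pad up to the next multiple of chunk_size.
pad_size = (chunk_size - x.shape[1] % chunk_size) % chunk_size
chunks = pad_and_chunk(x, pad_size, chunk_size)
print(pad_size, chunks.shape)       # 3 (2, 2, 4, 3)
```

The first chunk holds the first four original time steps; the second holds the fifth step followed by three zero-padded steps.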

easydel.modules.mamba2.modeling_mamba2_flax.segment_sum(input_tensor)[source]#

Computes the segment sum in a numerically more stable way, using cumulative sums and masking instead of direct subtractions.
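A sketch of one way to compute a segment sum with cumulative sums and masking, following the formulation commonly used in Mamba2 reference code. Whether this module uses exactly this layout is an assumption; the name stable_segment_sum is hypothetical. Entry [i, j] holds the sum of x[j+1..i] for i >= j and -inf above the diagonal.

```python
import jax.numpy as jnp


def stable_segment_sum(x):
    """out[..., i, j] = sum(x[..., j+1 : i+1]) for i >= j, else -inf."""
    T = x.shape[-1]
    idx = jnp.arange(T)
    strict_lower = idx[:, None] > idx[None, :]    # True where row > col
    lower = idx[:, None] >= idx[None, :]          # True where row >= col

    # Copy x along a new last axis: x_rep[..., d, e] = x[..., d].
    x_rep = jnp.repeat(x[..., None], T, axis=-1)
    # Zero entries on and above the diagonal so that a cumulative sum down
    # the rows adds only x[j+1..i]; no difference of large cumsums is taken.
    x_rep = jnp.where(strict_lower, x_rep, 0.0)
    out = jnp.cumsum(x_rep, axis=-2)
    # Positions with i < j are not valid segments; mark them with -inf.
    return jnp.where(lower, out, -jnp.inf)


x = jnp.array([1.0, 2.0, 3.0])
seg = stable_segment_sum(x)
print(seg)
# [[  0. -inf -inf]
#  [  2.   0. -inf]
#  [  5.   3.   0.]]
```

On the lower triangle the result matches the naive cumsum(x)[i] - cumsum(x)[j], but it is built from masked additions only, which avoids cancellation between large cumulative values.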