easydel.modules.mamba.modeling_mamba_flax

class easydel.modules.mamba.modeling_mamba_flax.Lambda(*args: Any, **kwargs: Any)

Bases: Module

fn: Callable
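`Lambda` exposes a single `fn: Callable` field. The sketch below illustrates the idea under one assumption: the module forwards its input to `fn` when called. It is a plain-Python stand-in (`LambdaSketch`, a hypothetical frozen dataclass), not the Flax `Module` itself.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class LambdaSketch:
    """Hypothetical stand-in for Lambda: wraps a callable as a layer."""

    fn: Callable[..., Any]

    def __call__(self, x: Any, **kwargs: Any) -> Any:
        # Assumption: the wrapped callable is applied directly to the input.
        return self.fn(x, **kwargs)


# Usage: wrap a parameter-free transform so it can sit among other layers.
double = LambdaSketch(fn=lambda x: [2 * v for v in x])
print(double([1, 2, 3]))  # [2, 4, 6]
```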
class easydel.modules.mamba.modeling_mamba_flax.MambaBlock(*args: Any, **kwargs: Any)

Bases: Module
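This is a minimal sketch of a Mamba residual block. It assumes the pre-norm layout of the reference Mamba design: RMS-normalize the input, run the mixer, and add the result back to the un-normalized input. The functions `rms_norm` and `mamba_block_sketch` are hypothetical and operate on one token vector as a Python list.

```python
import math
from typing import Callable, List

Vector = List[float]


def rms_norm(x: Vector, weight: Vector, eps: float = 1e-6) -> Vector:
    # Scale x by the reciprocal of its root-mean-square, then by a learned weight.
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def mamba_block_sketch(
    x: Vector, weight: Vector, mixer: Callable[[Vector], Vector]
) -> Vector:
    # Pre-norm residual: the mixer sees the normalized input, and its output
    # is added onto the original (un-normalized) input.
    mixed = mixer(rms_norm(x, weight))
    return [a + b for a, b in zip(x, mixed)]
```

With a mixer that outputs zeros, the block reduces to the identity, which is what makes stacking many such blocks stable.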

class easydel.modules.mamba.modeling_mamba_flax.MambaCausalLMOutput(last_hidden_state: Union[Array, ndarray, bool, number] = None, hidden_states: Optional[Tuple[Union[Array, ndarray, bool, number]]] = None, attentions: Optional[Tuple[Union[Array, ndarray, bool, number]]] = None, past_key_values: Optional[Dict[str, Union[Array, ndarray, bool, number]]] = None, loss: Optional[Union[Array, ndarray, bool, number]] = None)

Bases: FlaxBaseModelOutput

cache: Optional[MambaCache] = None
hidden_states: Optional[Tuple[Union[Array, ndarray, bool, number]]] = None
logits: Union[Array, ndarray, bool, number] = None
replace(**kwargs)
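The output classes expose `replace(**kwargs)`. The sketch below assumes this method behaves like the immutable-dataclass update used by Flax struct dataclasses: it returns a new object with the named fields overwritten and leaves the original unchanged. `CausalLMOutputSketch` is a hypothetical stand-in that mirrors only the fields listed above.

```python
import dataclasses
from typing import Any, Optional, Tuple


@dataclasses.dataclass(frozen=True)
class CausalLMOutputSketch:
    """Hypothetical stand-in mirroring the fields listed for MambaCausalLMOutput."""

    logits: Any = None
    cache: Optional[Any] = None
    hidden_states: Optional[Tuple[Any, ...]] = None

    def replace(self, **kwargs: Any) -> "CausalLMOutputSketch":
        # Immutable update: returns a new instance; the original is untouched.
        return dataclasses.replace(self, **kwargs)


out = CausalLMOutputSketch(logits=[0.1, 0.9])
out2 = out.replace(cache="new-cache")
print(out.cache, out2.cache)  # None new-cache
```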
class easydel.modules.mamba.modeling_mamba_flax.MambaConv1D(*args: Any, **kwargs: Any)

Bases: Module
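In the reference Mamba architecture, the short convolution in front of the SSM is depthwise (one filter per channel) and causal (zero-padded on the left so step t never sees future steps). The sketch below assumes `MambaConv1D` plays that role; `causal_depthwise_conv1d` is a hypothetical pure-Python helper, not the module's API.

```python
from typing import List

Sequence2D = List[List[float]]  # indexed [time][channel]


def causal_depthwise_conv1d(
    x: Sequence2D, kernels: List[List[float]], bias: List[float]
) -> Sequence2D:
    """Depthwise causal 1D convolution over time.

    x:       [seq_len][channels]
    kernels: [channels][kernel_size]; kernels[c][-1] weights the current step.
    bias:    [channels]
    """
    seq_len, channels = len(x), len(x[0])
    out = []
    for t in range(seq_len):
        row = []
        for c in range(channels):
            k = kernels[c]
            acc = bias[c]
            # Only positions <= t contribute (implicit zero left-padding).
            for j, w in enumerate(k):
                src = t - (len(k) - 1) + j
                if src >= 0:
                    acc += w * x[src][c]
            row.append(acc)
        out.append(row)
    return out
```

Each channel is filtered independently, so the parameter count is `channels * kernel_size` rather than `channels**2 * kernel_size`.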

class easydel.modules.mamba.modeling_mamba_flax.MambaForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

init_cache(batch_size: int, max_length: int)
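A Mamba cache holds a fixed-size recurrent state rather than a growing key/value history. The sketch below computes per-layer cache shapes under an assumption borrowed from Hugging Face's Mamba implementation: a convolution state of shape `(batch, intermediate_size, conv_kernel)` and an SSM state of shape `(batch, intermediate_size, state_size)`. The config class and `init_cache_shapes` are hypothetical; the actual `MambaCache` layout may differ.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class MambaCacheConfigSketch:
    """Hypothetical config subset; field names follow Hugging Face's MambaConfig."""

    num_hidden_layers: int
    intermediate_size: int
    conv_kernel: int
    state_size: int


def init_cache_shapes(
    cfg: MambaCacheConfigSketch, batch_size: int, max_length: int
) -> List[Dict[str, Tuple[int, int, int]]]:
    # max_length is unused here: a recurrent state has a fixed size,
    # unlike a transformer KV cache that grows with the sequence.
    del max_length
    return [
        {
            "conv_state": (batch_size, cfg.intermediate_size, cfg.conv_kernel),
            "ssm_state": (batch_size, cfg.intermediate_size, cfg.state_size),
        }
        for _ in range(cfg.num_hidden_layers)
    ]
```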
prepare_inputs_for_generation(input_ids, max_length, **kwargs)

Prepares the model inputs for a generation task.

Parameters
  • input_ids – The input token IDs.

  • max_length – The maximum length of the sequence to be generated.

  • attention_mask – tp.Optional[chex.Array]: Mask applied to the attention weights.

  • token_type_ids – tp.Optional[chex.Array]: Token type IDs.

Returns

A dictionary containing the past_key_values, attention_mask and position IDs.
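One plausible shape for this step is sketched below. Before the first decoding step, the batch size is read from the prompt, a cache is created once via an `init_cache(batch_size, max_length)`-style callable, and any extra keyword arguments pass through unchanged. `prepare_inputs_sketch` is hypothetical. It returns a `"cache"` key rather than the `past_key_values` named above, which is an assumption about how a Mamba model stores its state.

```python
from typing import Any, Callable, Dict, List


def prepare_inputs_sketch(
    input_ids: List[List[int]],
    max_length: int,
    init_cache: Callable[[int, int], Any],
    **kwargs: Any,
) -> Dict[str, Any]:
    # Hypothetical: derive the batch size from the prompt and create the
    # recurrent cache once, before the first decoding step.
    batch_size = len(input_ids)
    model_kwargs = dict(kwargs)
    if model_kwargs.get("cache") is None:
        model_kwargs["cache"] = init_cache(batch_size, max_length)
    return model_kwargs
```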

update_inputs_for_generation(outputs: MambaOutput, model_kwargs: Dict[str, Any], **kwargs) → Dict[str, Any]
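Going by its signature, this hook takes the outputs of one decoding step and returns the keyword arguments for the next. The sketch below assumes the essential job is to carry the updated recurrent state (`outputs.cache`) forward. `StepOutput` and `update_inputs_sketch` are hypothetical stand-ins.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class StepOutput:
    """Hypothetical stand-in for MambaOutput: only the field used here."""

    cache: Optional[Any] = None


def update_inputs_sketch(
    outputs: StepOutput, model_kwargs: Dict[str, Any], **kwargs: Any
) -> Dict[str, Any]:
    # Carry the recurrent state produced by this step into the next call.
    # Note: model_kwargs is updated in place and also returned.
    model_kwargs["cache"] = outputs.cache
    return model_kwargs
```

In a decoding loop, each step's output replaces the cache in `model_kwargs`, so the state never has to be recomputed from the full prefix.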
class easydel.modules.mamba.modeling_mamba_flax.MambaMixer(*args: Any, **kwargs: Any)

Bases: Module
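The core of a Mamba mixer is a selective scan: a linear recurrence whose step size Δ and projections B and C depend on the input. The sketch below assumes the recurrence from the Mamba reference implementation, using zero-order hold for A and the simplified Δ·B term for the input. It shows a single channel with a diagonal A. `selective_scan_channel` is hypothetical and omits the projections, gating and convolution that surround the scan.

```python
import math
from typing import List


def selective_scan_channel(
    x: List[float],        # inputs for one channel, length L
    delta: List[float],    # per-step step sizes, length L
    A: List[float],        # diagonal state matrix for this channel, length N
    B: List[List[float]],  # input-dependent B_t, shape [L][N]
    C: List[List[float]],  # input-dependent C_t, shape [L][N]
    D: float,              # skip-connection weight
) -> List[float]:
    n_state = len(A)
    h = [0.0] * n_state
    ys = []
    for t, x_t in enumerate(x):
        for n in range(n_state):
            # h_t = exp(delta_t * A) * h_{t-1} + delta_t * B_t * x_t
            h[n] = math.exp(delta[t] * A[n]) * h[n] + delta[t] * B[t][n] * x_t
        # y_t = <C_t, h_t> + D * x_t
        ys.append(sum(C[t][n] * h[n] for n in range(n_state)) + D * x_t)
    return ys
```

With `A = 0`, `Δ = B = C = 1` and `D = 0`, the state never decays, and the scan reduces to a running sum of the inputs.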

class easydel.modules.mamba.modeling_mamba_flax.MambaModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

init_cache(batch_size: int, max_length: int)
class easydel.modules.mamba.modeling_mamba_flax.MambaOutput(last_hidden_state: Union[Array, ndarray, bool, number] = None, hidden_states: Optional[Tuple[Union[Array, ndarray, bool, number]]] = None, attentions: Optional[Tuple[Union[Array, ndarray, bool, number]]] = None, past_key_values: Optional[Dict[str, Union[Array, ndarray, bool, number]]] = None, loss: Optional[Union[Array, ndarray, bool, number]] = None)

Bases: FlaxBaseModelOutput

cache: Optional[MambaCache] = None
hidden_states: Optional[Tuple[Union[Array, ndarray, bool, number]]] = None
last_hidden_state: Union[Array, ndarray, bool, number] = None
replace(**kwargs)
easydel.modules.mamba.modeling_mamba_flax.create_tuple_parser(n: int) → Callable[[Union[_T, Sequence[_T]]], tuple[_T, ...]]
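The signature says `create_tuple_parser(n)` returns a function that maps either a single value or a sequence to an `n`-tuple. The sketch below assumes the common convention for such helpers: a scalar is repeated `n` times, and a sequence must already have length `n`. Raising `ValueError` on a length mismatch and treating strings as scalars are choices made for this hypothetical `tuple_parser_sketch`.

```python
import collections.abc
from itertools import repeat
from typing import Callable, Sequence, Tuple, TypeVar, Union

_T = TypeVar("_T")


def tuple_parser_sketch(n: int) -> Callable[[Union[_T, Sequence[_T]]], Tuple[_T, ...]]:
    def parse(x: Union[_T, Sequence[_T]]) -> Tuple[_T, ...]:
        # Sequences (other than strings) must already have length n.
        if isinstance(x, collections.abc.Sequence) and not isinstance(x, str):
            if len(x) != n:
                raise ValueError(f"expected a sequence of length {n}, got {len(x)}")
            return tuple(x)
        # Scalars are broadcast to an n-tuple.
        return tuple(repeat(x, n))

    return parse


to_pair = tuple_parser_sketch(2)
print(to_pair(3), to_pair([1, 2]))  # (3, 3) (1, 2)
```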
easydel.modules.mamba.modeling_mamba_flax.init_to_value(x, dtype)
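Judging by its name and arguments, `init_to_value(x, dtype)` builds a parameter initializer that always yields the fixed value `x` cast to `dtype`. Such an initializer could be useful for SSM parameters that start from a deterministic value, though that use is an assumption. The sketch below assumes the returned initializer ignores whatever arguments it receives, such as an RNG key and a shape. `init_to_value_sketch` is hypothetical and uses a Python list and a type constructor in place of an array and an array dtype.

```python
from typing import Any, Callable, List


def init_to_value_sketch(
    x: List[Any], dtype: Callable[[Any], Any]
) -> Callable[..., List[Any]]:
    # Hypothetical: the initializer ignores its arguments (e.g. an RNG key
    # and a shape) and always returns x cast element-wise to dtype.
    def init(*_args: Any, **_kwargs: Any) -> List[Any]:
        return [dtype(v) for v in x]

    return init


init = init_to_value_sketch([1, 2], float)
print(init("rng-key", (2,)))  # [1.0, 2.0]
```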