easydel.modules.rwkv.modeling_rwkv

class easydel.modules.rwkv.modeling_rwkv.RwkvCausalLMOutput(logits: Union[Array, ndarray, bool, number] = None, state: list[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], ...] | None = None, attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], ...] | None = None)

Bases: ModelOutput

Output type for RWKV causal language model.

attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], ...] | None = None

classmethod from_dict(data: dict[str, Any]) -> T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) -> T

Deserializes a JSON string into a PyTree object.

hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], ...] | None = None

logits: Union[Array, ndarray, bool, number] = None

replace(**kwargs)

Creates a new instance with specified fields replaced.

state: list[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None

to_dict() -> dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) -> str

Serializes the PyTree object to a JSON string.
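The helper methods listed above (replace, to_dict, from_dict, to_json, from_json) follow a common immutable-record pattern. The sketch below illustrates that pattern with a standard-library dataclass. `TinyOutput` is a hypothetical stand-in, not the EasyDeL class, and it stores plain lists instead of arrays so that JSON serialization works without extra handling; how EasyDeL serializes array fields is not stated here.

```python
import dataclasses
import json
from typing import Any, Optional


@dataclasses.dataclass(frozen=True)
class TinyOutput:
    """Hypothetical stand-in for an output record (not the EasyDeL class)."""

    logits: Optional[list] = None
    state: Optional[list] = None

    def replace(self, **kwargs) -> "TinyOutput":
        # Returns a new instance; the original is left untouched.
        return dataclasses.replace(self, **kwargs)

    def to_dict(self) -> dict[str, Any]:
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "TinyOutput":
        return cls(**data)

    def to_json(self, **kwargs) -> str:
        # Extra keyword arguments are forwarded to json.dumps (an assumption
        # made for this sketch).
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str: str) -> "TinyOutput":
        return cls.from_dict(json.loads(json_str))


out = TinyOutput(logits=[0.1, 0.9])
updated = out.replace(state=[0.0, 0.0])       # new object with one field changed
restored = TinyOutput.from_json(updated.to_json())  # JSON round trip
```

Because the record is frozen, carrying a recurrent state from one call to the next means building a new output with replace rather than mutating the old one.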

class easydel.modules.rwkv.modeling_rwkv.RwkvFeedForward(*args: Any, **kwargs: Any)

Bases: Module

RWKV feedforward network with channel mixing.
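As a rough illustration of channel mixing as it is usually described for RWKV-4, the sketch below blends each token with the previous one (a token shift), applies a squared-ReLU key projection, and gates the result with a sigmoid "receptance". It uses NumPy for readability. The function name, parameter names, and shapes are assumptions for this example and are not taken from the EasyDeL source, which may differ in detail.

```python
import numpy as np


def channel_mix_sketch(hidden, mix_k, mix_r, w_key, w_receptance, w_value):
    """Illustrative channel mixing (hypothetical helper, not the EasyDeL API).

    hidden: (seq_len, d); mix_k, mix_r: (d,)
    w_key: (d, d_ff); w_receptance: (d, d); w_value: (d_ff, d)
    """
    # Token shift: position t sees the hidden state of t - 1 (zeros at t = 0).
    shifted = np.vstack([np.zeros_like(hidden[:1]), hidden[:-1]])

    # Per-channel interpolation between the current and the previous token.
    key_in = hidden * mix_k + shifted * (1.0 - mix_k)
    rec_in = hidden * mix_r + shifted * (1.0 - mix_r)

    key = np.square(np.maximum(key_in @ w_key, 0.0))  # squared ReLU
    value = key @ w_value
    receptance = 1.0 / (1.0 + np.exp(-(rec_in @ w_receptance)))  # sigmoid gate
    return receptance * value
```

The token shift is the only place where positions interact, so each output depends on the current and the immediately preceding token only.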

class easydel.modules.rwkv.modeling_rwkv.RwkvForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

RWKV model with language modeling head for causal generation.

class easydel.modules.rwkv.modeling_rwkv.RwkvModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

RWKV base model with embedding and transformer blocks.

class easydel.modules.rwkv.modeling_rwkv.RwkvOutput(last_hidden_state: Union[Array, ndarray, bool, number] = None, state: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], ...] | None = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], ...] | None = None, attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], ...] | None = None)

Bases: ModelOutput

Output type for RWKV model.

attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], ...] | None = None

classmethod from_dict(data: dict[str, Any]) -> T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) -> T

Deserializes a JSON string into a PyTree object.

hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], ...] | None = None

last_hidden_state: Union[Array, ndarray, bool, number] = None

replace(**kwargs)

Creates a new instance with specified fields replaced.

state: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], ...] | None = None

to_dict() -> dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) -> str

Serializes the PyTree object to a JSON string.

class easydel.modules.rwkv.modeling_rwkv.RwkvSelfAttention(*args: Any, **kwargs: Any)

Bases: Module

RWKV self-attention (time-mixing) mechanism with complexity linear in sequence length.

class easydel.modules.rwkv.modeling_rwkv.SingleStandRwkvBlock(*args: Any, **kwargs: Any)

Bases: Module

Single RWKV transformer block with attention and feedforward layers.

easydel.modules.rwkv.modeling_rwkv.init_state(hidden_size)

Create zeroed RWKV recurrent state tensors for a given hidden size.

easydel.modules.rwkv.modeling_rwkv.rwkv_linear_attention(time_decay, time_first, key, value, state=None, return_state=False)

Compute RWKV linear attention update with optional recurrent state.
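To make the recurrence concrete, here is a minimal NumPy sketch of the numerically stable WKV update commonly used for RWKV-4. It is not the EasyDeL implementation. The name `wkv_sketch`, the unbatched `(seq_len, channels)` layout, and the three-part state `(numerator, denominator, running max exponent)` are assumptions made for illustration, and the sketch always returns the state instead of taking a `return_state` flag.

```python
import numpy as np


def wkv_sketch(time_decay, time_first, key, value, state=None):
    """Illustrative WKV recurrence (hypothetical helper, not the EasyDeL API).

    key, value: (seq_len, channels); time_decay, time_first: (channels,).
    Returns (output, state), where state can seed the next call.
    """
    seq_len, channels = key.shape
    if state is None:
        num = np.zeros(channels)
        den = np.zeros(channels)
        max_exp = np.full(channels, -1e38)  # stands in for "log of zero"
    else:
        num, den, max_exp = state

    decay = -np.exp(time_decay)  # per-channel log-decay, always negative
    out = np.empty_like(value, dtype=np.float64)

    for t in range(seq_len):
        k, v = key[t], value[t]

        # Output: decayed history plus the current token boosted by time_first.
        # Subtracting the running max keeps every exponent <= 0.
        m = np.maximum(max_exp, k + time_first)
        e1 = np.exp(max_exp - m)
        e2 = np.exp(k + time_first - m)
        out[t] = (e1 * num + e2 * v) / (e1 * den + e2)

        # State update: decay the history, then add the current token.
        m = np.maximum(max_exp + decay, k)
        e1 = np.exp(max_exp + decay - m)
        e2 = np.exp(k - m)
        num = e1 * num + e2 * v
        den = e1 * den + e2
        max_exp = m

    return out, (num, den, max_exp)
```

Carrying the returned state into the next call should give the same outputs as processing the whole sequence at once, which is what allows token-by-token generation with constant memory per step.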