easydel.modules.llama.modeling_llama_flax
- class easydel.modules.llama.modeling_llama_flax.LlamaAttention(*args: Any, **kwargs: Any)
  Bases: FlaxAttentionModule
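The Llama architecture uses causal (autoregressive) self-attention, in which each token may attend only to itself and to earlier tokens. The following is a minimal single-head sketch of that masking idea. It uses NumPy in place of jax.numpy and leaves out the rotary position embeddings, multi-head projections, and caching that a full Llama attention layer would include. The function name `causal_attention` is hypothetical and is not part of EasyDeL.

```python
import numpy as np


def causal_attention(q, k, v):
    """Single-head scaled dot-product attention with a causal mask.

    q, k, v: arrays of shape (seq_len, head_dim).
    """
    seq_len, head_dim = q.shape
    scores = q @ k.T / np.sqrt(head_dim)
    # True above the diagonal: key position j lies in the future of query position i.
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    # Numerically stable softmax over the key axis; masked entries become exactly 0.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((4, 8)) for _ in range(3))
out = causal_attention(q, k, v)
print(out.shape)  # (4, 8)
```

The first query can only see the first key, so `out[0]` equals `v[0]`. Changing a later value vector leaves all earlier outputs unchanged.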
- class easydel.modules.llama.modeling_llama_flax.LlamaDecoderLayer(*args: Any, **kwargs: Any)
  Bases: Module
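In the Llama architecture, each decoder layer is a pre-norm residual block. RMSNorm is applied before the attention and MLP sublayers, and each sublayer's output is added back to its input. The following minimal NumPy sketch shows that data flow, with the two sublayers passed in as plain callables. The names `rms_norm` and `decoder_layer` are hypothetical and do not reflect the actual signature of LlamaDecoderLayer.

```python
import numpy as np


def rms_norm(x, weight, eps=1e-6):
    # Divide by the root-mean-square over the hidden axis (no mean subtraction), then scale.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * weight


def decoder_layer(x, attn_fn, mlp_fn, attn_norm_w, mlp_norm_w):
    """Pre-norm residual block: h = x + attn(norm(x)); out = h + mlp(norm(h))."""
    h = x + attn_fn(rms_norm(x, attn_norm_w))
    return h + mlp_fn(rms_norm(h, mlp_norm_w))


hidden = 8
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, hidden))  # (batch, seq_len, hidden)
ones = np.ones(hidden)
# Toy stand-ins for the attention and MLP sublayers.
y = decoder_layer(x, lambda t: 0.5 * t, np.tanh, ones, ones)
print(y.shape)  # (2, 5, 8)
```

Because of the residual connections, a layer whose sublayers both output zeros reduces to the identity.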
- class easydel.modules.llama.modeling_llama_flax.LlamaForCausalLM(*args: Any, **kwargs: Any)
  Bases: EasyDeLBaseModule
  Attributes:
  - config: Model configuration.
  - dtype (jnp.dtype): Data type for computations (default is jnp.bfloat16).
  - param_dtype (jnp.dtype): Data type for parameters (default is jnp.bfloat16).
  - precision (tp.Optional[tp.Union[str, jax.lax.Precision]]): Precision setting for JAX operations (default is "fastest").
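The split between dtype and param_dtype presumably follows the usual Flax convention. Under that convention, parameters are created and stored in param_dtype, and both inputs and parameters are cast to dtype for the forward computation. The following minimal sketch illustrates that assumed pattern. It uses NumPy's float32 and float16 as stand-ins because NumPy has no bfloat16. The helper `dense_forward` is hypothetical and is not an EasyDeL function. In JAX, the precision attribute would typically be passed on to calls that accept it, such as `jnp.dot(..., precision=...)`; a NumPy sketch cannot demonstrate that part.

```python
import numpy as np

param_dtype = np.float32  # storage dtype for parameters (stand-in)
dtype = np.float16        # computation dtype (stand-in for jnp.bfloat16)


def dense_forward(x, kernel, compute_dtype):
    """Hypothetical dense layer: cast activations and stored params to the compute dtype."""
    return x.astype(compute_dtype) @ kernel.astype(compute_dtype)


rng = np.random.default_rng(0)
kernel = rng.standard_normal((4, 3)).astype(param_dtype)  # stays in param_dtype
x = rng.standard_normal((2, 4))
y = dense_forward(x, kernel, dtype)
print(kernel.dtype, y.dtype)  # float32 float16
```

With this design, the stored weights keep full precision, and only the activations and matmul run in the cheaper dtype.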
- class easydel.modules.llama.modeling_llama_flax.LlamaForSequenceClassification(*args: Any, **kwargs: Any)
  Bases: EasyDeLBaseModule
- class easydel.modules.llama.modeling_llama_flax.LlamaMLP(*args: Any, **kwargs: Any)
  Bases: Module
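In its standard configuration, the Llama feed-forward block is the SwiGLU variant. It multiplies a SiLU-gated projection elementwise with a second up-projection, then projects the result back down, with no biases. The following is a minimal NumPy sketch of that computation. The weight names `w_gate`, `w_up`, and `w_down` loosely follow the gate_proj/up_proj/down_proj naming used by Hugging Face Transformers. They are assumptions here, not the actual attribute names of LlamaMLP.

```python
import numpy as np


def silu(x):
    # SiLU (swish) activation: x * sigmoid(x).
    return x / (1.0 + np.exp(-x))


def llama_mlp(x, w_gate, w_up, w_down):
    """SwiGLU feed-forward: down(silu(gate(x)) * up(x)), no biases."""
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down


hidden, intermediate = 8, 16
rng = np.random.default_rng(0)
w_gate = rng.standard_normal((hidden, intermediate))
w_up = rng.standard_normal((hidden, intermediate))
w_down = rng.standard_normal((intermediate, hidden))
x = rng.standard_normal((2, 5, hidden))  # (batch, seq_len, hidden)
y = llama_mlp(x, w_gate, w_up, w_down)
print(y.shape)  # (2, 5, 8)
```

The block expands to the intermediate width and then returns to the hidden width, so its output can be added back to the residual stream.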
- class easydel.modules.llama.modeling_llama_flax.LlamaModel(*args: Any, **kwargs: Any)
  Bases: EasyDeLBaseModule