easydel.modules.llama.modeling_llama_flax
- class easydel.modules.llama.modeling_llama_flax.LlamaAttention(*args: Any, **kwargs: Any)
Bases:
AttentionModule
- class easydel.modules.llama.modeling_llama_flax.LlamaDecoderLayer(*args: Any, **kwargs: Any)
Bases:
Module
- class easydel.modules.llama.modeling_llama_flax.LlamaForCausalLM(*args: Any, **kwargs: Any)
Bases:
EasyDeLBaseModule
Llama model with a language modeling head for causal language modeling.
This is a transformer-based language model that applies a causal attention mask, so each position attends only to itself and earlier positions. This masking is what enables autoregressive language generation.
- config
Configuration for the model.
- Type
- dtype
Data type for computations (default is jnp.float32).
- Type
jnp.dtype
- param_dtype
Data type for parameters (default is jnp.float32).
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
tp.Optional[tp.Union[str, jax.lax.Precision]]
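The causal masking described above can be sketched in plain `jax.numpy`. The sketch also shows how a computation `dtype` and a `precision` setting are commonly threaded through attention. It is an illustrative single-head, unbatched example, not EasyDeL's implementation. The function name `causal_attention` and its arguments are hypothetical.

```python
import jax
import jax.numpy as jnp


def causal_attention(q, k, v, dtype=jnp.float32, precision=None):
    """Hypothetical single-head, unbatched causal self-attention.

    q, k, v: arrays of shape (seq_len, head_dim).
    dtype: computation dtype (plays the role of the `dtype` attribute).
    precision: forwarded to jnp.einsum (plays the role of `precision`).
    Returns the attention output and the attention weights.
    """
    q, k, v = (x.astype(dtype) for x in (q, k, v))
    seq_len, head_dim = q.shape
    scale = 1.0 / head_dim ** 0.5
    # Raw scores between every query i and key j: shape (seq_len, seq_len).
    scores = jnp.einsum("id,jd->ij", q, k, precision=precision) * scale
    # Causal mask: query i may only see keys j <= i (the lower triangle).
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    # Masked entries get the most negative finite value, so softmax gives them ~0 weight.
    neg = jnp.asarray(jnp.finfo(dtype).min, dtype=dtype)
    scores = jnp.where(mask, scores, neg)
    weights = jax.nn.softmax(scores, axis=-1)
    out = jnp.einsum("ij,jd->id", weights, v, precision=precision)
    return out, weights


# Usage: 5 tokens with an 8-dimensional head, requesting full-precision matmuls.
q = jax.random.normal(jax.random.PRNGKey(0), (5, 8))
out, weights = causal_attention(q, q, q, precision=jax.lax.Precision.HIGHEST)
```

Passing `dtype=jnp.bfloat16` keeps the whole computation in bfloat16. `precision=jax.lax.Precision.HIGHEST` asks for the most accurate matmul algorithm on accelerators that would otherwise use reduced-precision passes.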
- class easydel.modules.llama.modeling_llama_flax.LlamaForSequenceClassification(*args: Any, **kwargs: Any)
Bases:
EasyDeLBaseModule
Llama model for sequence classification.
This class extends the base Llama model with a linear classification head for sequence classification tasks such as sentiment analysis or text classification.
- config
Configuration for the model.
- Type
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
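A linear classification head of the kind described above can be sketched as follows. This section does not say how EasyDeL pools the sequence. The sketch therefore assumes the convention of Hugging Face's `LlamaForSequenceClassification`: score every position with a bias-free linear layer, then keep the logits at the last non-padding token. It also assumes right padding. The function name `classify_sequences` and its argument layout are hypothetical.

```python
import jax.numpy as jnp


def classify_sequences(hidden_states, attention_mask, w):
    """Hypothetical linear classification head over a Llama backbone's outputs.

    hidden_states:  (batch, seq_len, hidden_size) final hidden states.
    attention_mask: (batch, seq_len), 1 for real tokens, 0 for right padding.
    w: (hidden_size, num_labels) weight matrix (no bias, by assumption).
    Returns logits of shape (batch, num_labels).
    """
    # Per-position logits: (batch, seq_len, num_labels).
    logits = hidden_states @ w
    # Index of the last real token in each sequence (assumes right padding).
    last_idx = attention_mask.sum(axis=-1) - 1
    batch_idx = jnp.arange(hidden_states.shape[0])
    # Pick one logit vector per sequence.
    return logits[batch_idx, last_idx]
```

Pooling at the last token suits a causal model. Under a causal mask, only the final real position has attended to the entire input.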
- class easydel.modules.llama.modeling_llama_flax.LlamaMLP(*args: Any, **kwargs: Any)
Bases:
Module
- class easydel.modules.llama.modeling_llama_flax.LlamaModel(*args: Any, **kwargs: Any)
Bases:
EasyDeLBaseModule
Llama model implementation.
This implements the Llama language model architecture: a stack of transformer blocks that use RMSNorm, rotary position embeddings (RoPE), and the self-attention mechanism implemented by LlamaAttention.
- config
Configuration for the model.
- Type
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
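The RMSNorm and rotary position embeddings mentioned above can be sketched in plain `jax.numpy`. This is a minimal illustration of the standard techniques, not EasyDeL's code. Several details are assumptions:

- the function names `rms_norm`, `rotate_half` and `apply_rope`;
- the defaults `eps=1e-6` and `theta=10000.0`, which is the RoPE base used by the original Llama;
- the "rotate-half" pairing of dimension i with i + head_dim // 2, which follows the Hugging Face convention.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    """RMSNorm over the last axis: x / sqrt(mean(x**2) + eps) * weight.

    Unlike LayerNorm there is no mean subtraction and no bias term.
    """
    x32 = x.astype(jnp.float32)  # normalize in float32 for numerical stability
    inv_rms = jax.lax.rsqrt(jnp.mean(x32 ** 2, axis=-1, keepdims=True) + eps)
    return (x32 * inv_rms).astype(x.dtype) * weight


def rotate_half(x):
    """Map the two halves (x1, x2) of the last axis to (-x2, x1)."""
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate([-x2, x1], axis=-1)


def apply_rope(x, positions, theta=10000.0):
    """Rotary position embedding for x of shape (seq_len, head_dim), head_dim even.

    Dimension i is paired with dimension i + head_dim // 2, and each pair is
    rotated by the angle position * theta ** (-2i / head_dim).
    """
    head_dim = x.shape[-1]
    inv_freq = theta ** (-jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim)
    angles = positions[:, None].astype(jnp.float32) * inv_freq[None, :]  # (seq, head_dim // 2)
    angles = jnp.concatenate([angles, angles], axis=-1)  # (seq, head_dim)
    return x * jnp.cos(angles) + rotate_half(x) * jnp.sin(angles)


# Usage: normalize 4 token vectors, then rotate each by its position 0..3.
x = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
normed = rms_norm(x, jnp.ones(8))
rotated = apply_rope(normed, jnp.arange(4))
```

Each pair of dimensions is rotated by an angle proportional to its position. As a result, the dot product between a rotated query at position m and a rotated key at position n depends only on m - n. This is how RoPE injects relative position information into attention.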