easydel.modules.llama.modeling_llama_flax

class easydel.modules.llama.modeling_llama_flax.LlamaAttention(*args: Any, **kwargs: Any)

Bases: AttentionModule

class easydel.modules.llama.modeling_llama_flax.LlamaDecoderLayer(*args: Any, **kwargs: Any)

Bases: Module

class easydel.modules.llama.modeling_llama_flax.LlamaForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Llama model with a language modeling head for causal language modeling tasks.

A transformer-based language model that applies causal attention masks, so each position attends only to itself and earlier positions, which is what makes autoregressive generation possible.
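To make the causal-mask idea concrete, here is a minimal single-head sketch in plain JAX. It is not EasyDeL's implementation: the helper `causal_attention` and the toy shapes are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp


def causal_attention(q, k, v):
    """Single-head scaled dot-product attention with a causal mask.

    Hypothetical helper for illustration; q, k, v have shape (seq_len, head_dim).
    """
    seq_len, head_dim = q.shape
    scores = (q @ k.T) / jnp.sqrt(head_dim)
    # Lower-triangular mask: position i may attend to positions 0..i only.
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    # Masked-out scores get the most negative finite value, so softmax gives them weight 0.
    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v, weights


# Toy inputs: one array of shape (3, 4, 8) unpacked into q, k, v of shape (4, 8).
q, k, v = jax.random.normal(jax.random.PRNGKey(0), (3, 4, 8))
out, weights = causal_attention(q, k, v)
```

Because every row of `weights` is zero above the diagonal, the output at position i does not depend on later tokens, so the model can be sampled one token at a time.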

config

Configuration for the model.

Type

LlamaConfig

dtype

Data type for computations (default is jnp.float32).

Type

jnp.dtype

param_dtype

Data type for parameters (default is jnp.float32).

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

tp.Optional[tp.Union[str, jax.lax.Precision]]
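The sketch below shows how the three numeric attributes typically interact in JAX code. It assumes the usual Flax-style convention (weights stored in `param_dtype`, cast to `dtype` for the computation, `precision` forwarded to the matmul). The source does not confirm that every EasyDeL layer follows exactly this pattern, and the variable names and shapes are hypothetical.

```python
import jax
import jax.numpy as jnp

param_dtype = jnp.float32              # storage dtype for weights
dtype = jnp.bfloat16                   # compute dtype for activations and matmuls
precision = jax.lax.Precision.HIGHEST  # illustrative choice; None lets JAX pick its default

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
kernel = jax.random.normal(key_w, (16, 32), dtype=param_dtype)  # hypothetical weight matrix
x = jax.random.normal(key_x, (4, 16), dtype=jnp.float32)        # hypothetical input batch

# Cast inputs and weights to the compute dtype right before the operation,
# leaving the stored parameters untouched.
y = jnp.dot(x.astype(dtype), kernel.astype(dtype), precision=precision)
```

Keeping `param_dtype` at float32 while computing in bfloat16 is a common way to reduce memory traffic without storing the weights at reduced precision.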

class easydel.modules.llama.modeling_llama_flax.LlamaForSequenceClassification(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Llama model for sequence classification tasks.

This class adds a linear classification head on top of the base Llama model for sequence-level tasks such as sentiment analysis or text classification.
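As a rough illustration of what a linear classification head does, the sketch below pools one hidden state per sequence and projects it to label logits. The pooling rule (last non-padding token, right-padded inputs) is an assumption borrowed from common Llama classification implementations and is not stated in this reference. `classify` and `score_kernel` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def classify(hidden_states, attention_mask, score_kernel):
    """Hypothetical classification head.

    hidden_states:  (batch, seq_len, hidden_size) output of the base model
    attention_mask: (batch, seq_len) with 1 for real tokens, 0 for padding
    score_kernel:   (hidden_size, num_labels) weights of the linear head
    """
    # Index of the last real token per sequence (assumes right padding).
    last_idx = attention_mask.sum(axis=-1) - 1
    batch_idx = jnp.arange(hidden_states.shape[0])
    pooled = hidden_states[batch_idx, last_idx]  # (batch, hidden_size)
    return pooled @ score_kernel                 # (batch, num_labels)


key_h, key_w = jax.random.split(jax.random.PRNGKey(0))
hidden_states = jax.random.normal(key_h, (2, 5, 8))
attention_mask = jnp.array([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
score_kernel = jax.random.normal(key_w, (8, 3))

logits = classify(hidden_states, attention_mask, score_kernel)
predicted = jnp.argmax(logits, axis=-1)  # one label id per sequence
```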

config

Configuration for the model.

Type

LlamaConfig

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

class easydel.modules.llama.modeling_llama_flax.LlamaMLP(*args: Any, **kwargs: Any)

Bases: Module

class easydel.modules.llama.modeling_llama_flax.LlamaModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Llama model implementation.

This implements the Llama language model architecture: a stack of transformer blocks that use RMSNorm, rotary position embeddings, and the attention mechanism provided by LlamaAttention.
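For readers unfamiliar with the two building blocks named above, here is a minimal JAX sketch of RMSNorm and rotary position embeddings. It is written under stated assumptions and is not the EasyDeL code: the `eps` and `base` values are illustrative defaults (the real values would come from `LlamaConfig`), and the rotate-half layout is one common convention among several.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    """Scale x by the reciprocal root-mean-square over the last axis (no mean subtraction)."""
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * weight


def rotate_half(x):
    """Map the two halves (x1, x2) of the last axis to (-x2, x1)."""
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate([-x2, x1], axis=-1)


def apply_rope(x, positions, base=10000.0):
    """Rotate feature pairs of x by position-dependent angles.

    x: (seq_len, head_dim) with even head_dim; positions: (seq_len,)
    """
    head_dim = x.shape[-1]
    inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2) / head_dim))
    angles = positions[:, None] * inv_freq[None, :]      # (seq_len, head_dim / 2)
    angles = jnp.concatenate([angles, angles], axis=-1)  # (seq_len, head_dim)
    return x * jnp.cos(angles) + rotate_half(x) * jnp.sin(angles)


x = jax.random.normal(jax.random.PRNGKey(0), (6, 8))
normed = rms_norm(x, jnp.ones(8))
rotated = apply_rope(x, jnp.arange(6))
```

RMSNorm leaves each row with a root-mean-square close to 1, while the rotary embedding is a pure rotation: it preserves vector norms and leaves position 0 unchanged.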

config

Configuration for the model.

Type

LlamaConfig

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.