easydel.modules.llama.modeling_llama

class easydel.modules.llama.modeling_llama.LlamaAttention(*args: Any, **kwargs: Any)

Bases: UnifiedAttention

Multi-head attention layer with rotary position embeddings (RoPE) for Llama models.
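The sketch below shows how rotary position embeddings combine with causal scaled dot-product attention. It is a minimal illustration and not LlamaAttention's implementation. The helper names (`rope_cos_sin`, `rotate_half`, `apply_rope`, `causal_attention_with_rope`) are hypothetical. The frequency base of 10000 is an assumption taken from the original Llama; the real value comes from the model configuration. The split-half ("rotate_half") channel pairing follows the Hugging Face Llama port, whereas some other implementations pair interleaved channels. Grouped-query attention, KV caching, and sharding are omitted. NumPy is used for portability, and the same operations exist in jax.numpy.

```python
import numpy as np

def rope_cos_sin(seq_len, head_dim, base=10000.0):
    # One inverse frequency per channel pair (head_dim / 2 pairs); base is an assumed default.
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))
    positions = np.arange(seq_len)[:, None]              # (seq_len, 1)
    angles = positions * inv_freq[None, :]               # (seq_len, head_dim / 2)
    angles = np.concatenate([angles, angles], axis=-1)   # (seq_len, head_dim)
    return np.cos(angles), np.sin(angles)

def rotate_half(x):
    # Pair channel i with channel i + head_dim/2 and rotate each pair by 90 degrees.
    half = x.shape[-1] // 2
    return np.concatenate([-x[..., half:], x[..., :half]], axis=-1)

def apply_rope(x, cos, sin):
    # x: (num_heads, seq_len, head_dim); cos/sin broadcast over the heads axis.
    return x * cos + rotate_half(x) * sin

def causal_attention_with_rope(q, k, v, base=10000.0):
    # q, k, v: (num_heads, seq_len, head_dim)
    _, seq_len, head_dim = q.shape
    cos, sin = rope_cos_sin(seq_len, head_dim, base)
    q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)  # rotate queries and keys only
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)    # (num_heads, seq_len, seq_len)
    # Causal mask: position t may attend only to positions <= t.
    mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)     # numerically stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v
```

Each channel pair is rotated by a position-dependent angle, so vector norms are unchanged. The query-key dot product ends up depending on the relative offset between the two positions.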

class easydel.modules.llama.modeling_llama.LlamaDecoderLayer(*args: Any, **kwargs: Any)

Bases: Module

Single decoder layer for Llama models.

Combines multi-head self-attention and a feedforward network. Each sublayer is preceded by RMS normalization and wrapped in a residual connection.
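A minimal sketch of the pre-norm residual pattern follows. It assumes the attention and feedforward sublayers are passed in as plain callables, so it is not the LlamaDecoderLayer source. The names `rms_norm` and `decoder_layer` are hypothetical, and `eps=1e-6` is an assumed default. In practice, the epsilon is read from the model configuration.

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    # Divide by the root-mean-square over the hidden dimension, then apply a learned scale.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return (x / rms) * weight

def decoder_layer(hidden, attn_fn, mlp_fn, attn_norm_w, mlp_norm_w, eps=1e-6):
    # Pre-norm: normalize the input, run the sublayer, then add the residual.
    hidden = hidden + attn_fn(rms_norm(hidden, attn_norm_w, eps))
    hidden = hidden + mlp_fn(rms_norm(hidden, mlp_norm_w, eps))
    return hidden
```

Unlike LayerNorm, RMSNorm does not subtract the mean and has no bias term, so it only rescales each hidden vector.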

class easydel.modules.llama.modeling_llama.LlamaForCausalLM(*args: Any, **kwargs: Any)

Bases: BaseCausalLMModule[LlamaModel, LlamaConfig]

Llama model with a language modeling head for causal language modeling tasks.

It applies a causal attention mask so that each position attends only to itself and earlier positions, which is what allows it to generate text autoregressively, one token at a time.
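The sketch below illustrates a language-modeling head and the standard next-token objective. Final hidden states are projected to vocabulary logits. The training loss then compares the logits at position t with the token at position t + 1, and greedy decoding picks the highest-scoring token at the last position. It is a generic illustration, not the code of BaseCausalLMModule. The function names are hypothetical, and details such as weight tying, padding/label masking, and sampling strategies are omitted.

```python
import numpy as np

def lm_head_logits(hidden, lm_head_w):
    # hidden: (seq_len, hidden_size); lm_head_w: (hidden_size, vocab_size)
    return hidden @ lm_head_w

def causal_lm_loss(logits, input_ids):
    # Position t predicts token t + 1: drop the last logit row and the first label.
    shift_logits = logits[:-1]
    shift_labels = input_ids[1:]
    shift_logits = shift_logits - shift_logits.max(axis=-1, keepdims=True)
    log_probs = shift_logits - np.log(np.exp(shift_logits).sum(axis=-1, keepdims=True))
    # Mean negative log-likelihood of the true next tokens.
    return -log_probs[np.arange(len(shift_labels)), shift_labels].mean()

def greedy_next_token(logits):
    # Greedy decoding: pick the most likely token after the last position.
    return int(np.argmax(logits[-1]))
```

With all-zero logits, every token is equally likely, so the loss equals log(vocab_size).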

config

Configuration for the model.

Type: LlamaConfig

dtype

Data type for computations (default is jnp.bfloat16).

Type: jnp.dtype

param_dtype

Data type for parameters (default is jnp.bfloat16).

Type: jnp.dtype

precision

Precision setting for JAX operations.

class easydel.modules.llama.modeling_llama.LlamaForSequenceClassification(*args: Any, **kwargs: Any)

Bases: BaseSequenceClassificationModule[LlamaModel, LlamaConfig]

Llama model for sequence classification tasks.

This class extends the base Llama model by adding a linear classification head to perform sequence classification tasks such as sentiment analysis or text classification.
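A hedged sketch of how a classification head is commonly applied to a decoder-only model follows. The hidden state of the last non-padding token is selected and passed through a linear layer to produce one logit per label. This pooling rule is an assumption that follows the Hugging Face convention for LlamaForSequenceClassification. BaseSequenceClassificationModule may pool differently. The sketch also assumes right padding and a bias-free head. The name `classify_last_token` is hypothetical.

```python
import numpy as np

def classify_last_token(hidden, attention_mask, score_w):
    # hidden: (batch, seq_len, hidden_size)
    # attention_mask: (batch, seq_len) integer array, 1 for real tokens, 0 for padding
    # score_w: (hidden_size, num_labels), assumed bias-free linear head
    last_idx = attention_mask.sum(axis=-1) - 1               # last real token (right padding assumed)
    pooled = hidden[np.arange(hidden.shape[0]), last_idx]    # (batch, hidden_size)
    return pooled @ score_w                                  # (batch, num_labels)
```

Pooling the final real token works because, under causal attention, it is the only position that has attended to the entire input.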

config

Configuration for the model.

Type: LlamaConfig

dtype

Data type for computations.

Type: jnp.dtype

param_dtype

Data type for parameters.

Type: jnp.dtype

precision

Precision setting for JAX operations.

class easydel.modules.llama.modeling_llama.LlamaMLP(*args: Any, **kwargs: Any)

Bases: Module

Multi-Layer Perceptron module for Llama models.

Implements the gated feedforward network with a SwiGLU activation used in the Llama architecture.
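As a minimal sketch of the SwiGLU feedforward computation, the block below uses two parallel projections. One projection is passed through SiLU and acts as a gate that multiplies the other elementwise, and a third projection maps the result back to the hidden size. The weight names `w_gate`, `w_up`, and `w_down` are hypothetical stand-ins for the module's three projections. Biases are omitted on the assumption that Llama's projections are bias-free.

```python
import numpy as np

def silu(x):
    # SiLU (swish): x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

def swiglu_mlp(x, w_gate, w_up, w_down):
    # x: (..., hidden_size)
    # w_gate, w_up: (hidden_size, intermediate_size); w_down: (intermediate_size, hidden_size)
    gated = silu(x @ w_gate) * (x @ w_up)   # elementwise gating in the intermediate space
    return gated @ w_down                   # project back to hidden_size
```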

class easydel.modules.llama.modeling_llama.LlamaModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Llama model implementation.

This implements the decoder-only Llama architecture. The model consists of a token embedding, a stack of decoder layers that use RMSNorm, rotary position embeddings (RoPE), and multi-head self-attention, and a final RMSNorm.
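The sketch below shows the overall data flow through the model: embedding lookup, the decoder layer stack, and a final RMSNorm. It is a simplified illustration that treats each decoder layer as an opaque callable, and the names `rms_norm` and `llama_model_forward` are hypothetical. The real LlamaModel also handles attention masks, KV caching, and sharding, none of which appears here. The `eps` default is an assumption.

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    # Rescale each hidden vector by its root-mean-square, then apply a learned scale.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return (x / rms) * weight

def llama_model_forward(input_ids, embed_table, layers, final_norm_w, eps=1e-6):
    # input_ids: (seq_len,) integer token ids; embed_table: (vocab_size, hidden_size)
    hidden = embed_table[input_ids]              # token embedding lookup
    for layer in layers:                         # each layer maps (seq, hidden) -> (seq, hidden)
        hidden = layer(hidden)
    return rms_norm(hidden, final_norm_w, eps)   # final RMSNorm; a task head would follow
```

The output is a sequence of hidden states rather than logits. This matches get_lm_head reporting that the base model has no language-model head.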

config

Configuration for the model.

Type: LlamaConfig

dtype

Data type for computations.

Type: jnp.dtype

param_dtype

Data type for parameters.

Type: jnp.dtype

precision

Precision setting for JAX operations.

get_decoder()

Returns the decoder part of the model’s graph definition.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition. Decoder-only models such as Llama don’t have an encoder.

get_lm_head()

Returns the language-model head of the module. Base models don’t have a language-model head.