easydel.modules.llama.modeling_llama_flax

class easydel.modules.llama.modeling_llama_flax.LlamaAttention(*args: Any, **kwargs: Any)[source]

Bases: FlaxAttentionModule

class easydel.modules.llama.modeling_llama_flax.LlamaDecoderLayer(*args: Any, **kwargs: Any)[source]

Bases: Module

class easydel.modules.llama.modeling_llama_flax.LlamaForCausalLM(*args: Any, **kwargs: Any)[source]

Bases: EasyDeLBaseModule

config

Configuration for the Llama causal language model.

Type

LlamaConfig

dtype

Data type for computations (default is jnp.bfloat16).

Type

jnp.dtype

param_dtype

Data type for parameters (default is jnp.bfloat16).

Type

jnp.dtype

precision

Precision setting for JAX operations (default is "fastest").

Type

tp.Optional[tp.Union[str, jax.lax.Precision]]

class easydel.modules.llama.modeling_llama_flax.LlamaForSequenceClassification(*args: Any, **kwargs: Any)[source]

Bases: EasyDeLBaseModule

class easydel.modules.llama.modeling_llama_flax.LlamaMLP(*args: Any, **kwargs: Any)[source]

Bases: Module

class easydel.modules.llama.modeling_llama_flax.LlamaModel(*args: Any, **kwargs: Any)[source]

Bases: EasyDeLBaseModule