easydel.modules.mistral.modeling_mistral_flax

class easydel.modules.mistral.modeling_mistral_flax.MistralAttention(*args: Any, **kwargs: Any)

Bases: AttentionModule
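This page describes Mistral as using sliding window attention. The sketch below shows the masking rule that idea implies, in NumPy rather than JAX: each query attends causally, and only to the most recent `window` positions. The helper names (`sliding_window_causal_mask`, `masked_softmax_attention`) are hypothetical and are not part of EasyDeL's `AttentionModule` API. Implementations also differ on the exact boundary convention; this sketch assumes the window includes the current token.

```python
import numpy as np


def sliding_window_causal_mask(seq_len: int, window: int) -> np.ndarray:
    """Boolean mask where mask[i, j] is True if query i may attend to key j.

    Causal: j <= i. Sliding window (assumed convention): at most `window`
    positions are visible, counting the query's own position, i.e. i - j < window.
    """
    i = np.arange(seq_len)[:, None]  # query positions, as a column
    j = np.arange(seq_len)[None, :]  # key positions, as a row
    return (j <= i) & (i - j < window)


def masked_softmax_attention(q, k, v, mask):
    """Single-head scaled dot-product attention with a boolean mask."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = np.where(mask, scores, -np.inf)  # block disallowed positions
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


mask = sliding_window_causal_mask(seq_len=5, window=3)
```

Because the diagonal is always visible, no row of the mask is empty, so the softmax never divides by zero. The mask keeps attention cost per token bounded by the window size rather than by the full sequence length.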

class easydel.modules.mistral.modeling_mistral_flax.MistralDecoderLayer(*args: Any, **kwargs: Any)

Bases: Module
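A Mistral decoder block is commonly arranged as a pre-norm residual block: normalize, run attention, add the residual, then normalize, run the MLP, and add the residual again. The NumPy sketch below shows only that wiring, with the sublayers passed in as plain callables. The function name `pre_norm_decoder_layer` is hypothetical, and whether `MistralDecoderLayer` matches this exact ordering is an assumption based on the standard Mistral design.

```python
from typing import Callable

import numpy as np

Sublayer = Callable[[np.ndarray], np.ndarray]


def pre_norm_decoder_layer(
    x: np.ndarray,
    attn: Sublayer,
    mlp: Sublayer,
    attn_norm: Sublayer,
    mlp_norm: Sublayer,
) -> np.ndarray:
    """One pre-norm block: h = x + attn(norm(x)); out = h + mlp(norm(h))."""
    h = x + attn(attn_norm(x))  # attention sublayer with residual connection
    return h + mlp(mlp_norm(h))  # feed-forward sublayer with residual connection


# Toy usage with identity norms and linear stand-in sublayers.
x = np.ones((2, 3))
out = pre_norm_decoder_layer(
    x, attn=lambda t: 2 * t, mlp=lambda t: t, attn_norm=lambda t: t, mlp_norm=lambda t: t
)
```

Keeping the residual path free of normalization is the usual motivation for pre-norm: gradients flow through the identity connections unchanged, which tends to make deep stacks easier to train.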

class easydel.modules.mistral.modeling_mistral_flax.MistralForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Mistral model with a language modeling head for causal language modeling tasks.

This model is a transformer-based language model that uses sliding window attention and generates text autoregressively, one token at a time.
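Causal language modeling trains the model so that the logits at position t predict the token at position t+1. The NumPy sketch below shows the standard shifted cross-entropy that objective implies. The function name `causal_lm_loss` is hypothetical; EasyDeL's own loss may additionally handle padding, ignore indices, or label smoothing, none of which is modeled here.

```python
import numpy as np


def causal_lm_loss(logits: np.ndarray, input_ids: np.ndarray) -> float:
    """Mean next-token cross-entropy.

    logits: (batch, seq_len, vocab) scores produced at every position.
    input_ids: (batch, seq_len) integer token ids.
    """
    shift_logits = logits[:, :-1, :]  # predictions made at positions 0..T-2
    shift_labels = input_ids[:, 1:]  # targets are the next tokens 1..T-1
    # Numerically stable log-softmax over the vocabulary axis.
    m = shift_logits.max(axis=-1, keepdims=True)
    log_probs = shift_logits - m - np.log(np.exp(shift_logits - m).sum(axis=-1, keepdims=True))
    # Log-probability assigned to each true next token.
    target_lp = np.take_along_axis(log_probs, shift_labels[..., None], axis=-1)[..., 0]
    return float(-target_lp.mean())


# Uniform logits over a 4-token vocabulary give a loss of log(4).
loss = causal_lm_loss(np.zeros((2, 5, 4)), np.zeros((2, 5), dtype=int))
```

The shift drops the last position's prediction (it has no next token to check) and the first token's label (nothing precedes it), which is why the loss is averaged over T-1 positions per sequence.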

config

Configuration for the model.

Type

MistralConfig

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype
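In Flax, `param_dtype` is the dtype parameters are created and stored in, while `dtype` is the dtype the computation runs in; that EasyDeL's modules follow this Flax convention is an assumption based on the attribute names. The NumPy sketch below illustrates the distinction with a hypothetical `dense_mixed_precision` helper. NumPy has no `bfloat16`, so `float16` stands in for the lower-precision compute dtype.

```python
import numpy as np


def dense_mixed_precision(x: np.ndarray, kernel: np.ndarray, dtype=np.float16) -> np.ndarray:
    """Cast stored parameters (param_dtype) and inputs to the compute dtype, then matmul."""
    return x.astype(dtype) @ kernel.astype(dtype)


# Parameters kept in float32 (the "param_dtype"), computation done in float16 (the "dtype").
kernel = np.random.default_rng(0).normal(size=(4, 3)).astype(np.float32)
x = np.ones((2, 4), dtype=np.float32)
y = dense_mixed_precision(x, kernel)
```

Storing parameters in higher precision keeps small optimizer updates from being rounded away, while computing in lower precision reduces memory traffic and speeds up matrix multiplications on accelerators.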

precision

Precision setting for JAX operations.

class easydel.modules.mistral.modeling_mistral_flax.MistralForSequenceClassification(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Mistral model for sequence classification tasks.

This class adds a linear classification head on top of the base Mistral model for sequence classification tasks such as sentiment analysis and other forms of text classification.
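Decoder-only classifiers commonly pool the hidden state of the last non-padding token, because under causal attention it is the only position that has seen the whole sequence, and then apply a linear head. The NumPy sketch below shows that pattern. The function name `classify_last_token` and the right-padding assumption are hypothetical; whether EasyDeL's class pools exactly this way is not stated in this reference.

```python
import numpy as np


def classify_last_token(
    hidden: np.ndarray, attention_mask: np.ndarray, w: np.ndarray, b: np.ndarray
) -> np.ndarray:
    """Linear head applied to the last real token of each sequence.

    hidden: (batch, seq_len, d_model) final hidden states.
    attention_mask: (batch, seq_len), 1 for real tokens and 0 for right padding.
    w: (d_model, num_labels) weights; b: (num_labels,) bias.
    """
    last_idx = attention_mask.sum(axis=1) - 1  # last real token index (assumes right padding)
    pooled = hidden[np.arange(hidden.shape[0]), last_idx]  # (batch, d_model)
    return pooled @ w + b  # (batch, num_labels) logits


hidden = np.arange(12, dtype=float).reshape(2, 3, 2)
mask = np.array([[1, 1, 0], [1, 1, 1]])
logits = classify_last_token(hidden, mask, np.eye(2), np.zeros(2))
```

Here the first sequence has one padding token, so its logits come from position 1, while the unpadded second sequence uses position 2.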

config

Configuration for the model.

Type

MistralConfig

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

class easydel.modules.mistral.modeling_mistral_flax.MistralMLP(*args: Any, **kwargs: Any)

Bases: Module
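The Mistral architecture uses a SiLU-gated feed-forward block (often called SwiGLU) with gate, up, and down projections. The NumPy sketch below shows that computation without biases. The names `gated_mlp`, `w_gate`, `w_up`, and `w_down` are illustrative; that `MistralMLP` uses these exact parameter names or layout is an assumption.

```python
import numpy as np


def silu(x: np.ndarray) -> np.ndarray:
    """SiLU (swish) activation: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))


def gated_mlp(x: np.ndarray, w_gate: np.ndarray, w_up: np.ndarray, w_down: np.ndarray) -> np.ndarray:
    """SiLU-gated feed-forward: down(silu(x @ w_gate) * (x @ w_up))."""
    gate = silu(x @ w_gate)  # (..., d_ff) activated gate branch
    up = x @ w_up  # (..., d_ff) linear branch
    return (gate * up) @ w_down  # project back to d_model


rng = np.random.default_rng(0)
d_model, d_ff = 4, 8
x = rng.normal(size=(2, d_model))
w_gate = rng.normal(size=(d_model, d_ff))
w_up = rng.normal(size=(d_model, d_ff))
w_down = rng.normal(size=(d_ff, d_model))
y = gated_mlp(x, w_gate, w_up, w_down)
```

The gate branch lets the network scale each hidden unit of the linear branch per token, which is the main difference from a plain two-layer MLP with a single activation.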

class easydel.modules.mistral.modeling_mistral_flax.MistralModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Mistral model implementation.

This class implements the Mistral language model architecture: a stack of transformer blocks that use RMSNorm, sliding window attention, and rotary position embeddings.
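Two of the components named above are small enough to sketch directly. The NumPy code below shows RMSNorm (scale by the root-mean-square of the features, with no mean subtraction) and rotary position embeddings in the "rotate half" layout used by common Llama/Mistral implementations. The function names, the `eps` value, and the default `base` of 10000 are assumptions for illustration, not values read from `MistralConfig`.

```python
import numpy as np


def rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """RMSNorm: divide by the root-mean-square over the last axis, then scale."""
    rms = np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + eps)
    return x / rms * weight


def rope(x: np.ndarray, positions: np.ndarray, base: float = 10000.0) -> np.ndarray:
    """Rotary embedding for x of shape (seq_len, head_dim) with even head_dim.

    Each feature pair (i, i + head_dim/2) is rotated by angle position * inv_freq[i].
    """
    d = x.shape[-1]
    inv_freq = 1.0 / base ** (np.arange(0, d, 2) / d)  # (d/2,) rotation frequencies
    angles = positions[:, None] * inv_freq[None, :]  # (seq_len, d/2)
    cos = np.cos(np.concatenate([angles, angles], axis=-1))
    sin = np.sin(np.concatenate([angles, angles], axis=-1))
    x1, x2 = x[..., : d // 2], x[..., d // 2 :]
    rotated = np.concatenate([-x2, x1], axis=-1)  # "rotate half"
    return x * cos + rotated * sin


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))
y = rope(x, np.arange(4, dtype=float))
n = rms_norm(x, np.ones(8))
```

Because RoPE applies a pure rotation to each feature pair, the dot product between a rotated query at position m and a rotated key at position n depends only on the offset m - n. That is how the model encodes relative position without adding a separate position vector.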

config

Configuration for the model.

Type

MistralConfig

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.