easydel.modules.arctic.modeling_arctic_flax

class easydel.modules.arctic.modeling_arctic_flax.ArcticAttention(*args: Any, **kwargs: Any)

Bases: AttentionModule

ArcticAttention module. Implements the attention mechanism for the Arctic model, with support for rotary position embeddings and selectable attention implementations.

config

Configuration object for the Arctic model.

Type

ArcticConfig

dtype

Data type for computation. Defaults to jnp.float32.

Type

jnp.dtype

param_dtype

Data type for parameters. Defaults to jnp.float32.

Type

jnp.dtype

precision

Precision setting for JAX operations (e.g., None, 'high', 'highest'). Defaults to None.

Type

jax.lax.PrecisionLike

rngs

Random number generators for the module.

Type

nn.Rngs
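
The two features named in the ArcticAttention description, rotary position embeddings and a causal attention computation, can be illustrated with a minimal plain-JAX sketch. This is not EasyDeL's implementation: the function names `rotary_embed` and `causal_attention`, the split-half rotation layout, and the base of 10000 are assumptions chosen for illustration. The `precision` argument shows where a `jax.lax.PrecisionLike` value such as the `precision` attribute would typically be forwarded.

```python
import jax
import jax.numpy as jnp


def rotary_embed(x, positions, base=10000.0):
    """Apply rotary position embeddings to x of shape (seq, heads, head_dim)."""
    half = x.shape[-1] // 2
    # One rotation frequency per pair of channels (assumed split-half layout).
    inv_freq = 1.0 / (base ** (jnp.arange(half, dtype=jnp.float32) / half))
    angles = positions[:, None].astype(jnp.float32) * inv_freq[None, :]  # (seq, half)
    cos = jnp.cos(angles)[:, None, :]  # broadcast over the heads axis
    sin = jnp.sin(angles)[:, None, :]
    x1, x2 = x[..., :half], x[..., half:]
    # Rotate each (x1, x2) pair by its position-dependent angle.
    rotated = jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    return rotated.astype(x.dtype)


def causal_attention(q, k, v, precision=None):
    """Scaled dot-product attention with a causal mask; inputs are (seq, heads, head_dim)."""
    seq, _, head_dim = q.shape
    scores = jnp.einsum("qhd,khd->hqk", q, k, precision=precision) / jnp.sqrt(head_dim)
    # Each query position may only attend to itself and earlier positions.
    mask = jnp.tril(jnp.ones((seq, seq), dtype=bool))
    scores = jnp.where(mask[None, :, :], scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("hqk,khd->qhd", weights, v, precision=precision)


key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
seq_len, num_heads, head_dim = 6, 2, 8
q = jax.random.normal(key_q, (seq_len, num_heads, head_dim), dtype=jnp.float32)
k = jax.random.normal(key_k, (seq_len, num_heads, head_dim), dtype=jnp.float32)
v = jax.random.normal(key_v, (seq_len, num_heads, head_dim), dtype=jnp.float32)
positions = jnp.arange(seq_len)

q_rot = rotary_embed(q, positions)
k_rot = rotary_embed(k, positions)
attn_out = causal_attention(q_rot, k_rot, v, precision="highest")
```

Because the rotation depends only on the position index, the query-key dot product ends up depending on the relative offset between two positions, which is the usual motivation for applying it to queries and keys rather than to values.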

class easydel.modules.arctic.modeling_arctic_flax.ArcticDecoderLayer(*args: Any, **kwargs: Any)

Bases: Module

Arctic Decoder Layer. This module combines the ArcticAttention and ArcticMoeBlock (or ArcticMLP) with layer normalization and residual connections to form a standard Transformer decoder layer.

config

Configuration object for the Arctic model.

Type

ArcticConfig

layer_idx

Index of this layer within the decoder stack.

Type

int

dtype

Data type for computation. Defaults to jnp.float32.

Type

jnp.dtype

param_dtype

Data type for parameters. Defaults to jnp.float32.

Type

jnp.dtype

precision

Precision setting for JAX operations. Defaults to None.

Type

jax.lax.PrecisionLike

rngs

Random number generators for the module.

Type

nn.Rngs
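
The combination of normalization, a sub-module, and a residual connection described for ArcticDecoderLayer can be sketched in plain JAX as follows. This is an illustrative pre-norm layout under stated assumptions, not the layer's actual wiring: the use of RMS normalization, the parameter names `input_norm` and `post_attention_norm`, and the callables `attention_fn` and `feed_forward_fn` are all hypothetical stand-ins.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    """Root-mean-square normalization over the last axis (assumed norm type)."""
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * weight


def decoder_layer(x, params, attention_fn, feed_forward_fn):
    # Attention sub-block: normalize, transform, add back to the residual stream.
    x = x + attention_fn(rms_norm(x, params["input_norm"]))
    # Feed-forward (MLP or MoE) sub-block with its own normalization.
    x = x + feed_forward_fn(rms_norm(x, params["post_attention_norm"]))
    return x


hidden_size = 16
params = {
    "input_norm": jnp.ones((hidden_size,)),
    "post_attention_norm": jnp.ones((hidden_size,)),
}
x = jax.random.normal(jax.random.PRNGKey(0), (4, hidden_size))

# Hypothetical stand-ins for the attention and feed-forward sub-modules.
zero_fn = lambda h: jnp.zeros_like(h)
scale_fn = lambda h: 0.5 * h

identity_out = decoder_layer(x, params, zero_fn, zero_fn)
layer_out = decoder_layer(x, params, scale_fn, scale_fn)
```

With sub-modules that return zeros, the layer reduces to the identity, which shows that the residual path carries the input through unchanged.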

class easydel.modules.arctic.modeling_arctic_flax.ArcticForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Arctic model for causal language modeling (CLM). This module wraps the core ArcticModel and adds a language modeling head on top.

config

Configuration object for the Arctic model.

Type

ArcticConfig

dtype

Data type for computation. Defaults to jnp.float32.

Type

jnp.dtype

param_dtype

Data type for parameters. Defaults to jnp.float32.

Type

jnp.dtype

precision

Precision setting for JAX operations. Defaults to None.

Type

jax.lax.PrecisionLike

rngs

Random number generators for the module.

Type

nn.Rngs
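
What a language modeling head adds on top of a backbone can be sketched in plain JAX: a projection from hidden states to vocabulary logits, typically trained with a next-token cross-entropy loss. The names `lm_head`, `head_kernel`, and `causal_lm_loss` are hypothetical, and the sketch makes no claim about how ArcticForCausalLM computes its outputs or loss.

```python
import jax
import jax.numpy as jnp


def lm_head(hidden_states, head_kernel, precision=None):
    """Project (seq, hidden) hidden states to (seq, vocab) logits."""
    return jnp.dot(hidden_states, head_kernel, precision=precision)


def causal_lm_loss(logits, input_ids):
    """Mean next-token cross-entropy: position t predicts token t + 1."""
    shift_logits = logits[:-1]
    shift_labels = input_ids[1:]
    log_probs = jax.nn.log_softmax(shift_logits, axis=-1)
    token_log_likelihood = jnp.take_along_axis(
        log_probs, shift_labels[:, None], axis=-1
    )[:, 0]
    return -jnp.mean(token_log_likelihood)


key_h, key_w = jax.random.split(jax.random.PRNGKey(0))
seq_len, hidden_size, vocab_size = 5, 16, 32
hidden_states = jax.random.normal(key_h, (seq_len, hidden_size))
head_kernel = jax.random.normal(key_w, (hidden_size, vocab_size)) * 0.02
input_ids = jnp.array([3, 7, 1, 0, 9])

logits = lm_head(hidden_states, head_kernel)
loss = causal_lm_loss(logits, input_ids)

# All-zero logits give a uniform distribution, so the loss equals log(vocab_size).
uniform_loss = causal_lm_loss(jnp.zeros((seq_len, vocab_size)), input_ids)
```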

class easydel.modules.arctic.modeling_arctic_flax.ArcticForSequenceClassification(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Arctic model adapted for sequence classification tasks. This module wraps the core ArcticModel and adds a classification head on top.

config

Configuration object for the Arctic model (must include num_labels).

Type

ArcticConfig

dtype

Data type for computation. Defaults to jnp.float32.

Type

jnp.dtype

param_dtype

Data type for parameters. Defaults to jnp.float32.

Type

jnp.dtype

precision

Precision setting for JAX operations. Defaults to None.

Type

jax.lax.PrecisionLike

rngs

Random number generators for the module.

Type

nn.Rngs
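
One common way to build a classification head on a decoder-only backbone is to pool the hidden state of the last non-padding token and project it to `num_labels` logits. The sketch below illustrates that pattern in plain JAX. Whether ArcticForSequenceClassification pools this way is an assumption, and the names `classify_sequences` and `score_kernel` are hypothetical.

```python
import jax
import jax.numpy as jnp


def classify_sequences(hidden_states, attention_mask, score_kernel):
    """Pool the last real token of each right-padded sequence and score it.

    hidden_states: (batch, seq, hidden)
    attention_mask: (batch, seq), 1 for real tokens and 0 for padding
    score_kernel: (hidden, num_labels)
    """
    last_index = jnp.sum(attention_mask, axis=-1) - 1  # (batch,)
    batch_index = jnp.arange(hidden_states.shape[0])
    pooled = hidden_states[batch_index, last_index]  # (batch, hidden)
    return pooled @ score_kernel


batch, seq_len, hidden_size, num_labels = 2, 4, 8, 3
key_h, key_w = jax.random.split(jax.random.PRNGKey(0))
hidden_states = jax.random.normal(key_h, (batch, seq_len, hidden_size))
score_kernel = jax.random.normal(key_w, (hidden_size, num_labels))
# The second sequence has two real tokens followed by two padding tokens.
attention_mask = jnp.array([[1, 1, 1, 1], [1, 1, 0, 0]])

class_logits = classify_sequences(hidden_states, attention_mask, score_kernel)
```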

class easydel.modules.arctic.modeling_arctic_flax.ArcticMLP(*args: Any, **kwargs: Any)

Bases: Module

Arctic Multi-Layer Perceptron (MLP) block. Implements the feed-forward network used in the Arctic model. When is_residual_mlp is set, the block is configured as a residual MLP.

config

Configuration object for the Arctic model.

Type

ArcticConfig

dtype

Data type for computation. Defaults to jnp.float32.

Type

jnp.dtype

param_dtype

Data type for parameters. Defaults to jnp.float32.

Type

jnp.dtype

precision

Precision setting for JAX operations. Defaults to None.

Type

jax.lax.PrecisionLike

is_residual_mlp

Whether this MLP block is a residual MLP. Defaults to False.

Type

bool

rngs

Random number generators for the module.

Type

nn.Rngs
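
As a rough picture of a Transformer feed-forward block, here is a minimal gated MLP in plain JAX. The gated SiLU form and the kernel names `w_gate`, `w_up`, and `w_down` are assumptions for illustration. The sketch does not reproduce ArcticMLP's layout or the effect of `is_residual_mlp`.

```python
import jax
import jax.numpy as jnp


def gated_mlp(x, w_gate, w_up, w_down):
    """Gated feed-forward: an activated gate multiplies a linear up-projection."""
    gate = jax.nn.silu(x @ w_gate)     # (tokens, intermediate)
    up = x @ w_up                      # (tokens, intermediate)
    return (gate * up) @ w_down        # back to (tokens, hidden)


hidden_size, intermediate_size = 8, 32
key_x, key_g, key_u, key_d = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(key_x, (4, hidden_size))
w_gate = jax.random.normal(key_g, (hidden_size, intermediate_size)) * 0.1
w_up = jax.random.normal(key_u, (hidden_size, intermediate_size)) * 0.1
w_down = jax.random.normal(key_d, (intermediate_size, hidden_size)) * 0.1

mlp_out = gated_mlp(x, w_gate, w_up, w_down)
zero_out = gated_mlp(jnp.zeros_like(x), w_gate, w_up, w_down)
```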

class easydel.modules.arctic.modeling_arctic_flax.ArcticModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Core Arctic model architecture. This module implements the main Transformer stack for the Arctic model, including token embeddings and decoder layers.

config

Configuration object for the Arctic model.

Type

ArcticConfig

dtype

Data type for computation. Defaults to jnp.float32.

Type

jnp.dtype

param_dtype

Data type for parameters. Defaults to jnp.float32.

Type

jnp.dtype

precision

Precision setting for JAX operations. Defaults to None.

Type

jax.lax.PrecisionLike

rngs

Random number generators for the module.

Type

nn.Rngs
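
The overall data flow of a Transformer stack, a token embedding lookup followed by a sequence of layers, can be sketched in plain JAX as below. `toy_layer` is a hypothetical stand-in for a full decoder layer, and `model_forward` is an illustrative name. The real ArcticModel also handles details such as masking and caching that this sketch omits.

```python
import jax
import jax.numpy as jnp


def toy_layer(hidden, kernel):
    # Hypothetical stand-in for a decoder layer: a residual nonlinear update.
    return hidden + jnp.tanh(hidden @ kernel)


def model_forward(input_ids, embedding, layer_kernels):
    """Embed token ids, then apply each layer in order."""
    hidden = jnp.take(embedding, input_ids, axis=0)  # (batch, seq, hidden)
    for kernel in layer_kernels:
        hidden = toy_layer(hidden, kernel)
    return hidden


vocab_size, hidden_size, num_layers = 32, 8, 3
key_e, key_l = jax.random.split(jax.random.PRNGKey(0))
embedding = jax.random.normal(key_e, (vocab_size, hidden_size))
layer_kernels = jax.random.normal(key_l, (num_layers, hidden_size, hidden_size)) * 0.1
input_ids = jnp.array([[1, 5, 2], [7, 0, 3]])

hidden_states = model_forward(input_ids, embedding, layer_kernels)
# With no layers, the output is just the embedding lookup.
embeddings_only = model_forward(input_ids, embedding, [])
```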

class easydel.modules.arctic.modeling_arctic_flax.ArcticMoeBlock(*args: Any, **kwargs: Any)

Bases: Module

Arctic Mixture of Experts (MoE) block. Implements the MoE layer used in the Arctic model: a gating mechanism scores the experts for each token and routes the token to the selected experts.

config

Configuration object for the Arctic model.

Type

ArcticConfig

layer_idx

Index of this layer within the decoder stack.

Type

int

dtype

Data type for computation. Defaults to jnp.float32.

Type

jnp.dtype

param_dtype

Data type for parameters. Defaults to jnp.float32.

Type

jnp.dtype

precision

Precision setting for JAX operations. Defaults to None.

Type

jax.lax.PrecisionLike

rngs

Random number generators for the module.

Type

nn.Rngs
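
Routing tokens to experts through a gating mechanism can be sketched in plain JAX with top-k selection. The sketch is a dense, unoptimized illustration under stated assumptions: the names `moe_forward`, `router_kernel`, and `expert_kernels`, the choice of `top_k=2`, the renormalization of the selected probabilities, and the single-matrix experts are all hypothetical and are not taken from ArcticMoeBlock.

```python
import jax
import jax.numpy as jnp


def moe_forward(x, router_kernel, expert_kernels, top_k=2):
    """Route each token to its top_k experts and mix their outputs.

    x: (tokens, hidden)
    router_kernel: (hidden, experts)
    expert_kernels: (experts, hidden, hidden), one linear map per expert
    """
    num_experts = router_kernel.shape[-1]
    router_probs = jax.nn.softmax(x @ router_kernel, axis=-1)
    top_probs, top_experts = jax.lax.top_k(router_probs, top_k)
    # Renormalize so the selected experts' weights sum to one per token.
    top_probs = top_probs / jnp.sum(top_probs, axis=-1, keepdims=True)
    # Scatter the selected weights into a dense (tokens, experts) gate matrix.
    one_hot = jax.nn.one_hot(top_experts, num_experts)  # (tokens, top_k, experts)
    gates = jnp.sum(one_hot * top_probs[..., None], axis=1)
    # Dense evaluation for clarity: every expert processes every token.
    expert_out = jnp.einsum("th,ehd->ted", x, expert_kernels)
    mixed = jnp.einsum("te,ted->td", gates, expert_out)
    return mixed, gates


tokens, hidden_size, num_experts = 5, 8, 4
key_x, key_r, key_e = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key_x, (tokens, hidden_size))
router_kernel = jax.random.normal(key_r, (hidden_size, num_experts))
expert_kernels = jax.random.normal(key_e, (num_experts, hidden_size, hidden_size)) * 0.1

moe_out, gates = moe_forward(x, router_kernel, expert_kernels, top_k=2)
```

Production MoE layers usually avoid the dense evaluation shown here and dispatch each token only to its selected experts. The dense form is used because it keeps the routing arithmetic easy to follow.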