easydel.modules.arctic.modeling_arctic#

class easydel.modules.arctic.modeling_arctic.ArcticAttention(*args: Any, **kwargs: Any)[source]#

Bases: UnifiedAttention

Arctic Attention module with sliding window support.

Inherits from UnifiedAttention with Arctic-specific customizations:

- Sliding window attention
- Custom bias configuration (uses attention_bias config)
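
To make the sliding window idea concrete, here is a minimal sketch of a causal mask restricted to a fixed window. It is not taken from ArcticAttention or UnifiedAttention; the function name `sliding_window_causal_mask` and its `window` parameter are hypothetical, and the actual module may build or apply its mask differently.

```python
import jax.numpy as jnp


def sliding_window_causal_mask(seq_len: int, window: int) -> jnp.ndarray:
    """Boolean [seq_len, seq_len] mask; True where query i may attend to key j.

    Hypothetical helper for illustration only, not part of EasyDeL.
    """
    q = jnp.arange(seq_len)[:, None]  # query positions as a column
    k = jnp.arange(seq_len)[None, :]  # key positions as a row
    causal = k <= q                   # never attend to future tokens
    in_window = (q - k) < window      # keep only the most recent `window` keys
    return causal & in_window


mask = sliding_window_causal_mask(seq_len=5, window=3)
```

With `window=3`, each query sees itself and at most the two preceding positions, so attention cost per token stays bounded as the sequence grows.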

class easydel.modules.arctic.modeling_arctic.ArcticDecoderLayer(*args: Any, **kwargs: Any)[source]#

Bases: Module

Arctic Decoder Layer. This module combines the ArcticAttention and ArcticMoeBlock (or ArcticMLP) with layer normalization and residual connections to form a standard Transformer decoder layer.

config#

Configuration object for the Arctic model.

Type

ArcticConfig

layer_idx#

The index of the current layer.

Type

int

dtype#

Data type for computation. Defaults to jnp.float32.

Type

jnp.dtype

param_dtype#

Data type for parameters. Defaults to jnp.float32.

Type

jnp.dtype

precision#

Precision setting for JAX operations. Defaults to None.

Type

jax.lax.PrecisionLike

rngs#

Random number generators for the module.

Type

nn.Rngs
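
The combination of attention, feed-forward, normalization, and residual connections described for ArcticDecoderLayer can be sketched roughly as follows. This is an illustration under assumptions, not the EasyDeL implementation: it assumes a pre-normalization ordering and an RMS-style norm without a learned scale, and `attention_fn` and `feed_forward_fn` are hypothetical stand-ins for ArcticAttention and ArcticMoeBlock (or ArcticMLP).

```python
import jax.numpy as jnp


def rms_norm(x, eps=1e-6):
    # Assumed normalization: RMS over the hidden dimension, learned scale omitted.
    return x / jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)


def decoder_layer(x, attention_fn, feed_forward_fn):
    # Attention sub-block: normalize, transform, add back to the residual stream.
    x = x + attention_fn(rms_norm(x))
    # Feed-forward sub-block: an MoE block or a dense MLP in the real layer.
    x = x + feed_forward_fn(rms_norm(x))
    return x


hidden = jnp.ones((2, 4, 8))  # [batch, sequence, hidden]
out = decoder_layer(
    hidden,
    attention_fn=lambda h: 0.5 * h,      # placeholder for attention
    feed_forward_fn=lambda h: jnp.tanh(h),  # placeholder for the MLP / MoE block
)
```

Because each sub-block only adds to `x`, the layer preserves the input shape, which is what lets many such layers be stacked.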

class easydel.modules.arctic.modeling_arctic.ArcticForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: BaseCausalLMModule[ArcticModel, ArcticConfig]

Arctic model with a Causal Language Modeling head.
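
As a rough illustration of what a causal language modeling head computes, the sketch below projects hidden states to vocabulary logits and evaluates a next-token loss. The helper names `lm_head_logits` and `next_token_loss` are hypothetical and the shapes are toy values; ArcticForCausalLM's actual call signature and loss handling are not shown on this page.

```python
import jax
import jax.numpy as jnp


def lm_head_logits(hidden_states, lm_head_kernel):
    # [batch, seq, hidden] @ [hidden, vocab] -> [batch, seq, vocab]
    return hidden_states @ lm_head_kernel


def next_token_loss(logits, input_ids):
    # Position t predicts token t + 1, so drop the last logit and the first label.
    log_probs = jax.nn.log_softmax(logits[:, :-1, :], axis=-1)
    targets = input_ids[:, 1:]
    picked = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    return -picked.mean()


hidden_states = jnp.zeros((1, 4, 8))   # toy decoder output
lm_head_kernel = jnp.zeros((8, 16))    # toy projection to a 16-token vocabulary
input_ids = jnp.array([[3, 1, 4, 1]])
logits = lm_head_logits(hidden_states, lm_head_kernel)
loss = next_token_loss(logits, input_ids)
```

With all-zero logits the prediction is uniform over the 16 tokens, so the loss equals log(16), a handy sanity check for an untrained head.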

class easydel.modules.arctic.modeling_arctic.ArcticForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: BaseSequenceClassificationModule[ArcticModel, ArcticConfig]

Arctic model with a Sequence Classification head.
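
A sequence classification head typically reduces the per-token hidden states to one vector per sequence and projects it to label scores. The sketch below assumes right padding and pooling at the last non-padding token; both are assumptions for illustration, and `pool_last_token`, `classify`, and `score_kernel` are hypothetical names rather than EasyDeL APIs.

```python
import jax.numpy as jnp


def pool_last_token(hidden_states, attention_mask):
    # Index of the last non-padding token per sequence (assumes right padding).
    last = attention_mask.sum(axis=-1) - 1      # [batch]
    batch = jnp.arange(hidden_states.shape[0])  # [batch]
    return hidden_states[batch, last]           # [batch, hidden]


def classify(hidden_states, attention_mask, score_kernel):
    # Project the pooled vector to one score per label: [batch, num_labels].
    return pool_last_token(hidden_states, attention_mask) @ score_kernel


hidden_states = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape(2, 3, 4)
attention_mask = jnp.array([[1, 1, 0], [1, 1, 1]])  # first sequence has one pad token
pooled = pool_last_token(hidden_states, attention_mask)
scores = classify(hidden_states, attention_mask, jnp.ones((4, 2)))
```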

class easydel.modules.arctic.modeling_arctic.ArcticMLP(*args: Any, **kwargs: Any)[source]#

Bases: Module

Arctic Multi-Layer Perceptron (MLP) block. This block implements the feed-forward network used in the Arctic model. It can optionally function as a residual MLP.

config#

Configuration object for the Arctic model.

Type

ArcticConfig

dtype#

Data type for computation. Defaults to jnp.float32.

Type

jnp.dtype

param_dtype#

Data type for parameters. Defaults to jnp.float32.

Type

jnp.dtype

precision#

Precision setting for JAX operations. Defaults to None.

Type

jax.lax.PrecisionLike

is_residual_mlp#

Whether this MLP block is a residual MLP. Defaults to False.

Type

bool

rngs#

Random number generators for the module.

Type

nn.Rngs
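
For orientation, a gated feed-forward block of the kind common in recent decoder models can be sketched as below. The kernel names `w1`, `w3`, and `w2` follow the names that appear in the `reform_param` entry of ArcticMLPMoE; the gating structure and the SiLU activation are assumptions, and this page does not confirm that ArcticMLP computes exactly this.

```python
import jax
import jax.numpy as jnp


def gated_mlp(x, w1, w3, w2):
    # Assumed roles: w1 = gate projection, w3 = up projection, w2 = down projection.
    gate = jax.nn.silu(x @ w1)
    up = x @ w3
    return (gate * up) @ w2


key = jax.random.PRNGKey(0)
k1, k2, k3, kx = jax.random.split(key, 4)
hidden, inner = 8, 16
w1 = jax.random.normal(k1, (hidden, inner), dtype=jnp.float32)
w3 = jax.random.normal(k3, (hidden, inner), dtype=jnp.float32)
w2 = jax.random.normal(k2, (inner, hidden), dtype=jnp.float32)
x = jax.random.normal(kx, (2, 4, hidden), dtype=jnp.float32)  # [batch, seq, hidden]
y = gated_mlp(x, w1, w3, w2)
```

The output has the same shape as the input, so the block can sit on the residual stream of a decoder layer.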

class easydel.modules.arctic.modeling_arctic.ArcticMLPMoE(*args: Any, **kwargs: Any)[source]#

Bases: Module

MoE variant of the Arctic Multi-Layer Perceptron (MLP) block. Like ArcticMLP, it implements the feed-forward network used in the Arctic model and can optionally function as a residual MLP.

config#

Configuration object for the Arctic model.

Type

ArcticConfig

dtype#

Data type for computation. Defaults to jnp.float32.

Type

jnp.dtype

param_dtype#

Data type for parameters. Defaults to jnp.float32.

Type

jnp.dtype

precision#

Precision setting for JAX operations. Defaults to None.

Type

jax.lax.PrecisionLike

is_residual_mlp#

Whether this MLP block is a residual MLP. Defaults to False.

Type

bool

rngs#

Random number generators for the module.

Type

nn.Rngs

reform_param: ClassVar = {'down_proj$': {'inverse_spliter': <function ArcticMLPMoE.<lambda>>, 'splits': [{'name': 'w2.kernel', 'spliter': <function ArcticMLPMoE.<lambda>>}]}, 'gate_up_proj$': {'inverse_spliter': <function ArcticMLPMoE.<lambda>>, 'splits': [{'name': 'w1.kernel', 'spliter': <function ArcticMLPMoE.<lambda>>}, {'name': 'w3.kernel', 'spliter': <function ArcticMLPMoE.<lambda>>}]}}#

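
The `reform_param` mapping pairs a fused `gate_up_proj` parameter with separate `w1.kernel` and `w3.kernel` entries, each with a splitter and an inverse splitter. The lambdas themselves are not shown here, so the following is only a guess at the general shape of such a pair: it assumes the fused kernel stores the two halves side by side along the last axis, gate half first. The names `split_gate_up` and `merge_gate_up` are hypothetical.

```python
import jax.numpy as jnp


def split_gate_up(fused):
    # Assumed layout: [..., 2 * inner], gate half (w1) first, up half (w3) second.
    w1, w3 = jnp.split(fused, 2, axis=-1)
    return w1, w3


def merge_gate_up(w1, w3):
    # Inverse of split_gate_up under the same layout assumption.
    return jnp.concatenate([w1, w3], axis=-1)


# Toy fused kernel: [experts, hidden, 2 * inner]
fused = jnp.arange(2 * 3 * 8, dtype=jnp.float32).reshape(2, 3, 8)
w1, w3 = split_gate_up(fused)
restored = merge_gate_up(w1, w3)
```

Whatever the real layout is, the useful property to check is the round trip: merging the split pieces should reproduce the fused parameter exactly.
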
class easydel.modules.arctic.modeling_arctic.ArcticModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Core Arctic model architecture. This module implements the main Transformer stack for the Arctic model, including token embeddings and decoder layers.

config#

Configuration object for the Arctic model.

Type

ArcticConfig

dtype#

Data type for computation. Defaults to jnp.float32.

Type

jnp.dtype

param_dtype#

Data type for parameters. Defaults to jnp.float32.

Type

jnp.dtype

precision#

Precision setting for JAX operations. Defaults to None.

Type

jax.lax.PrecisionLike

rngs#

Random number generators for the module.

Type

nn.Rngs

get_decoder() Module[source]#

Returns the decoder part of the model’s graph definition. For ArcticModel, this is the model itself.

get_embedding() Module[source]#

Returns the embedding layer of the module.

get_encoder() Module[source]#

Returns the encoder part of the model’s graph definition. For ArcticModel (decoder-only), this is not applicable.

get_lm_head() Module[source]#

Returns the language model head of the module. ArcticModel does not include the lm_head.

class easydel.modules.arctic.modeling_arctic.ArcticMoeBlock(*args: Any, **kwargs: Any)[source]#

Bases: BaseMoeModule

Arctic Mixture of Experts (MoE) block. This module implements the MoE layer used in the Arctic model, routing tokens to different experts based on a gating mechanism.

config#

Configuration object for the Arctic model.

Type

ArcticConfig

layer_idx#

The index of the current layer.

Type

int

dtype#

Data type for computation. Defaults to jnp.float32.

Type

jnp.dtype

param_dtype#

Data type for parameters. Defaults to jnp.float32.

Type

jnp.dtype

precision#

Precision setting for JAX operations. Defaults to None.

Type

jax.lax.PrecisionLike

rngs#

Random number generators for the module.

Type

nn.Rngs
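
The routing of tokens to experts through a gating mechanism can be sketched as a small top-k router. This is a dense reference written for readability, not the ArcticMoeBlock or BaseMoeModule implementation: the softmax-then-top-k ordering, the renormalization of the selected weights, `k=2`, and the names `top_k_route` and `moe_forward` are all assumptions.

```python
import jax
import jax.numpy as jnp


def top_k_route(router_logits, k):
    # router_logits: [tokens, num_experts]
    probs = jax.nn.softmax(router_logits, axis=-1)
    weights, expert_ids = jax.lax.top_k(probs, k)             # both [tokens, k]
    weights = weights / weights.sum(axis=-1, keepdims=True)   # renormalize over chosen experts
    return weights, expert_ids


def moe_forward(x, router_kernel, expert_kernels, k=2):
    # x: [tokens, hidden]; router_kernel: [hidden, E]; expert_kernels: [E, hidden, hidden]
    weights, expert_ids = top_k_route(x @ router_kernel, k)
    # Dense reference: run every expert on every token, then gather the chosen outputs.
    all_out = jnp.einsum("th,ehd->ted", x, expert_kernels)                # [tokens, E, hidden]
    chosen = jnp.take_along_axis(all_out, expert_ids[..., None], axis=1)  # [tokens, k, hidden]
    return jnp.sum(weights[..., None] * chosen, axis=1)                   # [tokens, hidden]


logits = jnp.array([[1.0, 3.0, 2.0, 0.0]])
weights, expert_ids = top_k_route(logits, k=2)

x = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)
router_kernel = jnp.ones((4, 3))
expert_kernels = jnp.stack([jnp.eye(4)] * 3)  # every toy expert is the identity
out = moe_forward(x, router_kernel, expert_kernels, k=2)
```

A production MoE layer would dispatch each token only to its selected experts instead of evaluating all of them; the dense version above trades efficiency for a compact statement of the math.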