easydel.modules.arctic.modeling_arctic_flax
- class easydel.modules.arctic.modeling_arctic_flax.ArcticAttention(*args: Any, **kwargs: Any)
Bases: AttentionModule
ArcticAttention module. This module implements the attention mechanism for the Arctic model, supporting rotary position embeddings and configurable attention implementations.
- config
Configuration object for the Arctic model.
- Type
- dtype
Data type for computation. Defaults to jnp.float32.
- Type
jnp.dtype
- param_dtype
Data type for parameters. Defaults to jnp.float32.
- Type
jnp.dtype
- precision
Precision setting for JAX operations (e.g., None, 'high', 'highest'). Defaults to None.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators for the module.
- Type
nn.Rngs
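The rotary position embeddings and masked attention that ArcticAttention is described as supporting can be illustrated in plain jax.numpy. This is a minimal sketch, not EasyDeL's implementation: the helper names rotary_embedding and causal_attention are hypothetical, the rotate-half RoPE convention and theta=10000.0 are assumptions, and grouped-query attention, KV caching, and the pluggable attention backends are omitted. It does show where the dtype and precision attributes would plausibly enter the computation.

```python
import jax
import jax.numpy as jnp


def rotary_embedding(x, positions, theta=10000.0):
    """Apply rotary position embeddings (rotate-half convention, assumed).

    x: [batch, seq, heads, head_dim] with even head_dim; positions: [batch, seq] ints.
    """
    head_dim = x.shape[-1]
    # One inverse frequency per pair of channels.
    inv_freq = 1.0 / (theta ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim))
    angles = positions[..., None].astype(jnp.float32) * inv_freq  # [batch, seq, head_dim // 2]
    angles = jnp.concatenate([angles, angles], axis=-1)[:, :, None, :]  # broadcast over heads
    x1, x2 = jnp.split(x, 2, axis=-1)
    rotated = jnp.concatenate([-x2, x1], axis=-1)
    return (x * jnp.cos(angles) + rotated * jnp.sin(angles)).astype(x.dtype)


def causal_attention(q, k, v, dtype=jnp.float32, precision=None):
    """Scaled dot-product attention with a causal mask.

    q, k, v: [batch, seq, heads, head_dim]; returns the same shape.
    """
    q, k, v = (t.astype(dtype) for t in (q, k, v))
    scale = q.shape[-1] ** -0.5
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k, precision=precision) * scale
    seq = q.shape[1]
    # Each query position may attend only to itself and earlier key positions.
    mask = jnp.tril(jnp.ones((seq, seq), dtype=bool))
    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    # Softmax in float32 for stability, then cast back to the compute dtype.
    weights = jax.nn.softmax(scores.astype(jnp.float32), axis=-1).astype(dtype)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v, precision=precision)


key = jax.random.PRNGKey(0)
q, k, v = jax.random.normal(key, (3, 1, 4, 2, 8))  # batch=1, seq=4, heads=2, head_dim=8
positions = jnp.arange(4)[None, :]
out = causal_attention(
    rotary_embedding(q, positions),
    rotary_embedding(k, positions),
    v,
    precision=jax.lax.Precision.HIGHEST,
)
print(out.shape)  # (1, 4, 2, 8)
```

Because the rotation is the identity at position 0 and preserves vector norms, and because the causal mask lets the first query see only the first key, the first output position equals the first value vector.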
- class easydel.modules.arctic.modeling_arctic_flax.ArcticDecoderLayer(*args: Any, **kwargs: Any)
Bases: Module
Arctic Decoder Layer. This module combines ArcticAttention and ArcticMoeBlock (or ArcticMLP) with layer normalization and residual connections to form a standard Transformer decoder layer.
- config
Configuration object for the Arctic model.
- Type
- layer_idx
The index of the current layer.
- Type
int
- dtype
Data type for computation. Defaults to jnp.float32.
- Type
jnp.dtype
- param_dtype
Data type for parameters. Defaults to jnp.float32.
- Type
jnp.dtype
- precision
Precision setting for JAX operations. Defaults to None.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators for the module.
- Type
nn.Rngs
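The layout described for ArcticDecoderLayer, normalization followed by a residual connection around each sublayer, can be sketched as a pure function. The pre-norm ordering, the use of RMSNorm, and the parameter names input_norm and post_attention_norm are assumptions made for illustration; attn_fn and ffn_fn are hypothetical stand-ins for ArcticAttention and ArcticMoeBlock (or ArcticMLP). The actual layer may arrange these pieces differently, for example by running an additional residual MLP in parallel.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    """RMS normalization computed in float32, then cast back to the input dtype."""
    x32 = x.astype(jnp.float32)
    inv_rms = jax.lax.rsqrt(jnp.mean(jnp.square(x32), axis=-1, keepdims=True) + eps)
    return (x32 * inv_rms * weight).astype(x.dtype)


def decoder_layer(hidden, params, attn_fn, ffn_fn):
    """Pre-norm decoder layer (hypothetical sketch).

    hidden: [batch, seq, hidden_size]
    attn_fn / ffn_fn: callables mapping [batch, seq, hidden_size] to the same shape.
    """
    # Attention sublayer wrapped in a residual connection.
    residual = hidden
    hidden = residual + attn_fn(rms_norm(hidden, params["input_norm"]))
    # Feed-forward (MoE or dense MLP) sublayer wrapped in a residual connection.
    residual = hidden
    return residual + ffn_fn(rms_norm(hidden, params["post_attention_norm"]))


hidden_size = 8
params = {"input_norm": jnp.ones(hidden_size), "post_attention_norm": jnp.ones(hidden_size)}
x = jax.random.normal(jax.random.PRNGKey(0), (2, 3, hidden_size))
# With sublayers that contribute nothing, the residual path returns the input unchanged.
zero = lambda h: jnp.zeros_like(h)
y = decoder_layer(x, params, zero, zero)
```

Keeping the residual stream un-normalized and normalizing only the sublayer inputs is the usual motivation for the pre-norm arrangement: the identity path stays intact, which the zero-sublayer example makes visible.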
- class easydel.modules.arctic.modeling_arctic_flax.ArcticForCausalLM(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
Arctic model specifically adapted for Causal Language Modeling (CLM). This module wraps the core ArcticModel and adds a language modeling head on top.
- config
Configuration object for the Arctic model.
- Type
- dtype
Data type for computation. Defaults to jnp.float32.
- Type
jnp.dtype
- param_dtype
Data type for parameters. Defaults to jnp.float32.
- Type
jnp.dtype
- precision
Precision setting for JAX operations. Defaults to None.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators for the module.
- Type
nn.Rngs
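The language modeling head and the usual next-token objective can be sketched as follows. This is an illustration under assumptions, not EasyDeL's code: lm_head_logits and causal_lm_loss are hypothetical helpers, the head is assumed to be a bias-free linear projection from hidden size to vocabulary size, and the loss is the standard shifted cross-entropy in which position t predicts token t + 1.

```python
import jax
import jax.numpy as jnp


def lm_head_logits(hidden_states, lm_head_kernel, precision=None):
    """Project hidden states [batch, seq, hidden] to vocabulary logits [batch, seq, vocab]."""
    return jnp.einsum("bsh,hv->bsv", hidden_states, lm_head_kernel, precision=precision)


def causal_lm_loss(logits, input_ids):
    """Mean next-token cross-entropy: logits at position t are scored against token t + 1."""
    shift_logits = logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]
    log_probs = jax.nn.log_softmax(shift_logits.astype(jnp.float32), axis=-1)
    token_log_probs = jnp.take_along_axis(log_probs, shift_labels[..., None], axis=-1)[..., 0]
    return -jnp.mean(token_log_probs)


hidden = jax.random.normal(jax.random.PRNGKey(0), (2, 5, 16))
input_ids = jnp.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
kernel = jnp.zeros((16, 32))  # all-zero head: uniform distribution over 32 tokens
logits = lm_head_logits(hidden, kernel)
loss = causal_lm_loss(logits, input_ids)  # equals log(32) for a uniform prediction
next_token = jnp.argmax(logits[:, -1, :], axis=-1)  # greedy choice for the next position
```

The all-zero head is a convenient sanity check: a model that predicts a uniform distribution over V tokens has a cross-entropy of exactly log(V).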
- class easydel.modules.arctic.modeling_arctic_flax.ArcticForSequenceClassification(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
Arctic model adapted for sequence classification tasks. This module wraps the core ArcticModel and adds a classification head on top.
- config
Configuration object for the Arctic model (must include num_labels).
- Type
- dtype
Data type for computation. Defaults to jnp.float32.
- Type
jnp.dtype
- param_dtype
Data type for parameters. Defaults to jnp.float32.
- Type
jnp.dtype
- precision
Precision setting for JAX operations. Defaults to None.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators for the module.
- Type
nn.Rngs
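Decoder-only classifiers commonly score the representation of the last non-padding token in each sequence. Whether ArcticForSequenceClassification pools exactly this way is an assumption here; the sketch below, with hypothetical helpers last_token_indices and pooled_classification_logits, assumes right-padded inputs and a known pad_token_id, and selects one row of per-token class scores (shape [batch, seq, num_labels]) per sequence.

```python
import jax.numpy as jnp


def last_token_indices(input_ids, pad_token_id):
    """Index of the last non-padding token in each row, assuming right padding."""
    first_pad = jnp.argmax((input_ids == pad_token_id).astype(jnp.int32), axis=-1)
    # With no padding, argmax returns 0; (0 - 1) % seq_len wraps to the last position.
    return (first_pad - 1) % input_ids.shape[-1]


def pooled_classification_logits(token_logits, input_ids, pad_token_id):
    """Pick token_logits[b, last_index[b], :] for every row b -> [batch, num_labels]."""
    idx = last_token_indices(input_ids, pad_token_id)
    return token_logits[jnp.arange(input_ids.shape[0]), idx]


input_ids = jnp.array([[5, 6, 7, 0, 0], [1, 2, 3, 4, 5]])  # row 0 right-padded with id 0
token_logits = jnp.arange(2 * 5 * 3).reshape(2, 5, 3)  # toy scores, num_labels = 3
pooled = pooled_classification_logits(token_logits, input_ids, pad_token_id=0)
```

A pad-id search is cheap but fragile: left-padded batches, or real tokens that happen to share the pad id, would need the attention mask instead.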
- class easydel.modules.arctic.modeling_arctic_flax.ArcticMLP(*args: Any, **kwargs: Any)
Bases: Module
Arctic Multi-Layer Perceptron (MLP) block. This block implements the feed-forward network used in the Arctic model. When is_residual_mlp is True, it serves as a residual MLP.
- config
Configuration object for the Arctic model.
- Type
- dtype
Data type for computation. Defaults to jnp.float32.
- Type
jnp.dtype
- param_dtype
Data type for parameters. Defaults to jnp.float32.
- Type
jnp.dtype
- precision
Precision setting for JAX operations. Defaults to None.
- Type
jax.lax.PrecisionLike
- is_residual_mlp
Whether this MLP block is a residual MLP. Defaults to False.
- Type
bool
- rngs
Random number generators for the module.
- Type
nn.Rngs
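A gated feed-forward block of the kind used by Llama- and Mixtral-style models can be sketched as below. That ArcticMLP uses this SwiGLU form with a SiLU activation, and that is_residual_mlp shrinks the inner width from intermediate_size to hidden_size, are assumptions not confirmed by this page; the weight names w1, w2, w3 and both helper functions are hypothetical. The sketch also shows the split between param_dtype (storage) and dtype (computation).

```python
import jax
import jax.numpy as jnp


def init_mlp_params(key, hidden_size, intermediate_size, is_residual_mlp=False, param_dtype=jnp.float32):
    """Create gate (w1), up (w3), and down (w2) projection weights, stored in param_dtype."""
    # Assumption: a residual MLP keeps its inner width equal to hidden_size.
    ffn_dim = hidden_size if is_residual_mlp else intermediate_size
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "w1": (jax.random.normal(k1, (hidden_size, ffn_dim)) * hidden_size**-0.5).astype(param_dtype),
        "w3": (jax.random.normal(k3, (hidden_size, ffn_dim)) * hidden_size**-0.5).astype(param_dtype),
        "w2": (jax.random.normal(k2, (ffn_dim, hidden_size)) * ffn_dim**-0.5).astype(param_dtype),
    }


def mlp_forward(params, x, dtype=jnp.float32, precision=None):
    """SwiGLU feed-forward: w2(silu(x @ w1) * (x @ w3)), computed in `dtype`."""
    x = x.astype(dtype)
    w1, w2, w3 = (params[name].astype(dtype) for name in ("w1", "w2", "w3"))
    gated = jax.nn.silu(jnp.matmul(x, w1, precision=precision)) * jnp.matmul(x, w3, precision=precision)
    return jnp.matmul(gated, w2, precision=precision)


key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (2, 3, 16))
dense = init_mlp_params(key, 16, 64, param_dtype=jnp.bfloat16)  # weights stored in bfloat16
residual = init_mlp_params(key, 16, 64, is_residual_mlp=True)
y_dense = mlp_forward(dense, x)  # activations computed in float32
y_residual = mlp_forward(residual, x)
```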
- class easydel.modules.arctic.modeling_arctic_flax.ArcticModel(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
Core Arctic model architecture. This module implements the main Transformer stack for the Arctic model, including token embeddings and decoder layers.
- config
Configuration object for the Arctic model.
- Type
- dtype
Data type for computation. Defaults to jnp.float32.
- Type
jnp.dtype
- param_dtype
Data type for parameters. Defaults to jnp.float32.
- Type
jnp.dtype
- precision
Precision setting for JAX operations. Defaults to None.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators for the module.
- Type
nn.Rngs
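The flow described for ArcticModel, token embeddings feeding a stack of decoder layers, can be sketched in plain JAX. Deriving position_ids from the attention mask, applying a final RMS normalization, and every helper name below are assumptions for illustration; each entry in layers is a hypothetical stand-in for an ArcticDecoderLayer.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    """RMS normalization computed in float32, then cast back to the input dtype."""
    x32 = x.astype(jnp.float32)
    return (x32 * jax.lax.rsqrt(jnp.mean(jnp.square(x32), -1, keepdims=True) + eps) * weight).astype(x.dtype)


def position_ids_from_mask(attention_mask):
    """Count only attended tokens; leading (left) padding would give -1, so clamp it to 0."""
    return jnp.maximum(jnp.cumsum(attention_mask, axis=-1) - 1, 0)


def model_forward(input_ids, embedding, layers, final_norm_weight, attention_mask=None):
    """input_ids: [batch, seq] -> final hidden states [batch, seq, hidden_size]."""
    if attention_mask is None:
        attention_mask = jnp.ones_like(input_ids)
    position_ids = position_ids_from_mask(attention_mask)
    hidden = embedding[input_ids]  # token embedding lookup
    for layer in layers:  # each layer: (hidden, position_ids) -> hidden
        hidden = layer(hidden, position_ids)
    return rms_norm(hidden, final_norm_weight)


vocab_size, hidden_size = 10, 4
embedding = jax.random.normal(jax.random.PRNGKey(0), (vocab_size, hidden_size))
input_ids = jnp.array([[0, 3, 5, 7]])
mask = jnp.array([[0, 1, 1, 1]])  # one left-padding token
toy_layer = lambda h, pos: h + 0.1 * pos[..., None]  # hypothetical decoder-layer stand-in
out = model_forward(input_ids, embedding, [toy_layer, toy_layer], jnp.ones(hidden_size), mask)
```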
- class easydel.modules.arctic.modeling_arctic_flax.ArcticMoeBlock(*args: Any, **kwargs: Any)
Bases: Module
Arctic Mixture of Experts (MoE) block. This module implements the MoE layer used in the Arctic model, routing tokens to different experts based on a gating mechanism.
- config
Configuration object for the Arctic model.
- Type
- layer_idx
The index of the current layer.
- Type
int
- dtype
Data type for computation. Defaults to jnp.float32.
- Type
jnp.dtype
- param_dtype
Data type for parameters. Defaults to jnp.float32.
- Type
jnp.dtype
- precision
Precision setting for JAX operations. Defaults to None.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators for the module.
- Type
nn.Rngs
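Top-k token routing of the kind ArcticMoeBlock performs can be sketched as follows. The Mixtral-style details (softmax over router logits, top-k selection, renormalizing the selected weights), the default top_k=2, and the rule that layer_idx chooses between MoE and dense layers through a moe_layer_frequency setting are all assumptions not confirmed by this page; is_moe_layer and moe_block are hypothetical helpers, and every expert is evaluated on every token purely for readability.

```python
import jax
import jax.numpy as jnp


def is_moe_layer(layer_idx, moe_layer_frequency):
    """Assumed rule: every moe_layer_frequency-th layer (counting from 1) is an MoE layer."""
    return (layer_idx + 1) % moe_layer_frequency == 0


def moe_block(x, gate_kernel, expert_fns, top_k=2):
    """Route each token to its top_k experts and mix their outputs.

    x: [tokens, hidden]; gate_kernel: [hidden, num_experts];
    expert_fns: list of callables [tokens, hidden] -> [tokens, hidden].
    """
    router_logits = x @ gate_kernel  # [tokens, num_experts]
    probs = jax.nn.softmax(router_logits.astype(jnp.float32), axis=-1)
    top_w, top_idx = jax.lax.top_k(probs, top_k)  # both [tokens, top_k]
    top_w = top_w / jnp.sum(top_w, axis=-1, keepdims=True)  # renormalize the selected weights
    num_experts = gate_kernel.shape[-1]
    # Dense combine matrix: zero weight for experts a token was not routed to.
    combine = jnp.sum(jax.nn.one_hot(top_idx, num_experts) * top_w[..., None], axis=-2)
    # For clarity every expert processes every token; efficient kernels dispatch only routed tokens.
    expert_out = jnp.stack([f(x) for f in expert_fns], axis=1)  # [tokens, num_experts, hidden]
    return jnp.einsum("te,teh->th", combine.astype(x.dtype), expert_out), router_logits


experts = [lambda h, s=s: h * s for s in (1.0, 2.0, 3.0, 4.0)]  # expert i scales by i + 1
x = jnp.ones((3, 4))  # 3 tokens, hidden size 4
gate_kernel = jnp.zeros((4, 4)).at[:, 3].set(10.0).at[:, 2].set(5.0)  # favour expert 3, then 2
routed_out, router_logits = moe_block(x, gate_kernel, experts, top_k=2)
```

With these toy weights the router gives expert 3 nearly all of the probability mass, so each token's output is almost exactly four times its input; returning router_logits as well leaves room for an auxiliary load-balancing loss, which this sketch does not compute.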