easydel.modules.clip.modeling_clip_flax#
- class easydel.modules.clip.modeling_clip_flax.CLIPAttention(*args: Any, **kwargs: Any)[source]#
Bases: AttentionModule
CLIP Attention module, supporting both text (causal) and vision (non-causal) attention.
- config#
Configuration object.
- Type
Union[CLIPTextConfig, CLIPVisionConfig]
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
JAX precision level.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
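The difference between causal (text) and non-causal (vision) attention can be illustrated with a minimal single-head sketch in plain JAX. This is not the EasyDeL implementation: the `attention` function, its arguments, and the toy shapes are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp


def attention(q, k, v, causal: bool):
    """Scaled dot-product attention for one head; q, k, v are (seq, head_dim)."""
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    if causal:
        seq = q.shape[0]
        # Lower-triangular mask: token i may only attend to tokens 0..i.
        mask = jnp.tril(jnp.ones((seq, seq), dtype=bool))
        scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v, weights


# Toy inputs (hypothetical sizes): 4 tokens, head dimension 8.
q, k, v = jax.random.normal(jax.random.PRNGKey(0), (3, 4, 8))
_, text_weights = attention(q, k, v, causal=True)     # text-style attention
_, vision_weights = attention(q, k, v, causal=False)  # vision-style attention
```

With `causal=True` every weight above the diagonal is zero, while the vision variant lets each position attend to all others.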
- class easydel.modules.clip.modeling_clip_flax.CLIPEncoder(*args: Any, **kwargs: Any)[source]#
Bases: Module
Transformer encoder consisting of CLIPEncoderLayer layers.
- config#
Configuration object.
- Type
Union[CLIPTextConfig, CLIPVisionConfig]
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
JAX precision level.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- property causal_mask#
Returns the causal mask if the encoder is for text, otherwise None.
- Returns
Causal mask.
- Return type
Optional[chex.Array]
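A minimal sketch of what such an optional causal mask might look like, assuming a boolean lower-triangular layout. The helper name `make_causal_mask` and its `is_text` flag are hypothetical and do not come from EasyDeL.

```python
from typing import Optional

import jax.numpy as jnp


def make_causal_mask(seq_len: int, is_text: bool) -> Optional[jnp.ndarray]:
    """Return a (seq_len, seq_len) boolean causal mask for text, else None."""
    if not is_text:
        # Vision encoders attend bidirectionally, so no mask is needed.
        return None
    return jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))


text_mask = make_causal_mask(4, is_text=True)
vision_mask = make_causal_mask(4, is_text=False)
```

Callers can then branch on `mask is None` instead of carrying a separate text/vision flag.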
- class easydel.modules.clip.modeling_clip_flax.CLIPEncoderLayer(*args: Any, **kwargs: Any)[source]#
Bases: Module
Single CLIP encoder layer, combining self-attention and MLP.
- config#
Configuration object.
- Type
Union[CLIPTextConfig, CLIPVisionConfig]
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
JAX precision level.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
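As a rough illustration of how self-attention and the MLP are combined, the sketch below uses the pre-layer-norm residual layout of the original CLIP architecture. Whether EasyDeL follows exactly this ordering is an assumption, and `layer_norm` and `encoder_layer` are hypothetical helpers (the layer norm here has no learned scale or bias).

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-5):
    """Normalize the last axis to zero mean and unit variance."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def encoder_layer(x, attn_fn, mlp_fn):
    """Two residual sub-blocks: attention first, then the MLP."""
    x = x + attn_fn(layer_norm(x))
    x = x + mlp_fn(layer_norm(x))
    return x


# Toy input: 4 tokens with hidden size 8; the sublayers are placeholders.
x = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
out = encoder_layer(x, attn_fn=lambda h: 0.1 * h, mlp_fn=lambda h: 0.1 * h)
```

Because both sub-blocks are residual, a layer whose sublayers output zeros reduces to the identity.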
- class easydel.modules.clip.modeling_clip_flax.CLIPForImageClassification(*args: Any, **kwargs: Any)[source]#
Bases: EasyDeLBaseModule
CLIP vision model with an image classification head on top (a linear layer on the pooled final hidden state).
- config#
Configuration object.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
JAX precision level.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
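The classification head described above can be sketched as a single linear map applied to a pooled vector. The pooling strategy (a mean over tokens), the sizes, and all variable names below are assumptions for illustration, not the EasyDeL implementation.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes: 2 images, 5 tokens each, hidden size 8, 3 classes.
batch, tokens, hidden, num_labels = 2, 5, 8, 3
key_h, key_w = jax.random.split(jax.random.PRNGKey(0))

# Stand-in for the vision transformer's final hidden states.
hidden_states = jax.random.normal(key_h, (batch, tokens, hidden))

# Linear classifier parameters.
w = jax.random.normal(key_w, (hidden, num_labels)) * 0.02
b = jnp.zeros((num_labels,))

pooled = hidden_states.mean(axis=1)  # (batch, hidden), assumed pooling
logits = pooled @ w + b              # (batch, num_labels)
predicted = logits.argmax(axis=-1)   # predicted class index per image
```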
- class easydel.modules.clip.modeling_clip_flax.CLIPMLP(*args: Any, **kwargs: Any)[source]#
Bases: Module
CLIP MLP (Feed-Forward) layer.
- config#
Configuration object.
- Type
Union[CLIPTextConfig, CLIPVisionConfig]
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
JAX precision level.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
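A feed-forward block of this kind is usually two linear layers with a nonlinearity in between. The sketch below assumes the `quick_gelu` activation used by the original CLIP checkpoints; the actual activation is selected by the configuration, and the function names and sizes here are hypothetical.

```python
import jax
import jax.numpy as jnp


def quick_gelu(x):
    """Sigmoid-based approximation of GELU."""
    return x * jax.nn.sigmoid(1.702 * x)


def mlp(x, w1, b1, w2, b2):
    """Expand to the intermediate size, apply the activation, project back."""
    return quick_gelu(x @ w1 + b1) @ w2 + b2


# Toy sizes: hidden size 4, intermediate size 16.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (2, 4))
w1, b1 = jax.random.normal(k2, (4, 16)) * 0.02, jnp.zeros((16,))
w2, b2 = jax.random.normal(k3, (16, 4)) * 0.02, jnp.zeros((4,))
out = mlp(x, w1, b1, w2, b2)
```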
- class easydel.modules.clip.modeling_clip_flax.CLIPModel(*args: Any, **kwargs: Any)[source]#
Bases: EasyDeLBaseModule
- compute_loss(*, labels=None, loss_config=None, loss_kwargs=None, **batch) → Tuple[Any, CLIPOutput][source]#
Computes the loss for the model given a batch of inputs and labels.
This method performs a forward pass using the provided batch arguments, then calculates the loss using the determined loss_function. It handles potential label inference (e.g., using input_ids as labels for Causal LM) and default loss configurations.
- Parameters
labels (tp.Optional[chex.Array], optional) – The target labels. If None and the task is Causal LM, input_ids from the batch might be used. Defaults to None.
loss_config (tp.Optional[LossConfig], optional) – Specific configuration for the loss calculation. If None, defaults might be inferred (e.g., for sequence classification). Defaults to None.
loss_kwargs (tp.Optional[tp.Dict], optional) – Additional keyword arguments to pass directly to the loss function. Defaults to None.
**batch – Keyword arguments representing the input batch (e.g., input_ids, attention_mask).
- Returns
- A tuple containing:
The model’s output (a Pytree, typically including logits, hidden states, etc.)
A LossMetrics object containing the calculated loss and potentially other metrics.
- Return type
tp.Tuple[tp.Any, LossMetrics]
- Raises
AssertionError – If labels are required for the loss function but are not provided or inferred.
AssertionError – If sequence classification loss is used without num_labels in the config.
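For context, CLIP-style models are typically trained with a symmetric contrastive loss over image-text similarity logits. The sketch below illustrates that general technique; it is an assumption that `compute_loss` resolves to something similar, and `clip_contrastive_loss` and its arguments are hypothetical names.

```python
import jax
import jax.numpy as jnp


def clip_contrastive_loss(image_embeds, text_embeds, logit_scale=1.0):
    """Symmetric cross-entropy over a batch of matched image/text pairs."""
    image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
    text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
    logits_per_text = logit_scale * text_embeds @ image_embeds.T
    labels = jnp.arange(logits_per_text.shape[0])

    def cross_entropy(logits):
        # The i-th row's correct class is column i (the matching pair).
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        return -log_probs[labels, labels].mean()

    # Average the text-to-image and image-to-text directions.
    return 0.5 * (cross_entropy(logits_per_text) + cross_entropy(logits_per_text.T))


eye = jnp.eye(3)
aligned = clip_contrastive_loss(eye, eye, logit_scale=100.0)
shuffled = clip_contrastive_loss(eye, jnp.roll(eye, 1, axis=0), logit_scale=100.0)
```

Perfectly aligned pairs drive the loss toward zero, while mismatched pairs are penalized heavily.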
- class easydel.modules.clip.modeling_clip_flax.CLIPTextEmbeddings(*args: Any, **kwargs: Any)[source]#
Bases: Module
Constructs the text embeddings for CLIP.
- config#
Configuration object.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
JAX precision level.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
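Text embeddings in CLIP-style models are commonly the sum of a token embedding and a learned position embedding. The sketch below shows that idea with plain lookup tables; the table names, sizes, and the `embed_text` helper are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes: vocabulary of 10, up to 6 positions, hidden size 4.
vocab_size, max_positions, hidden = 10, 6, 4
k_tok, k_pos = jax.random.split(jax.random.PRNGKey(0))
token_table = jax.random.normal(k_tok, (vocab_size, hidden))
position_table = jax.random.normal(k_pos, (max_positions, hidden))


def embed_text(input_ids):
    """Look up token vectors and add the position vector for each slot."""
    seq_len = input_ids.shape[-1]
    return token_table[input_ids] + position_table[:seq_len]


input_ids = jnp.array([[1, 4, 7], [2, 2, 9]])
embeddings = embed_text(input_ids)  # (batch, seq_len, hidden)
```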
- class easydel.modules.clip.modeling_clip_flax.CLIPTextModel(*args: Any, **kwargs: Any)[source]#
Bases: EasyDeLBaseModule
Bare CLIP text model (transformer) outputting raw hidden-states without any specific head on top.
- config#
Configuration object.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
JAX precision level.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- class easydel.modules.clip.modeling_clip_flax.CLIPTextModelWithProjection(*args: Any, **kwargs: Any)[source]#
Bases: EasyDeLBaseModule
CLIP text model with a projection layer on top.
- config#
Configuration object.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
JAX precision level.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
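A hedged sketch of what "a projection layer on top" usually means for CLIP text models: the hidden state at the end-of-sequence token is pooled and passed through a bias-free linear projection. The pooling rule below (taking the position of the largest token id, as in the original CLIP code) and all names and sizes are assumptions, not confirmed EasyDeL behavior.

```python
import jax
import jax.numpy as jnp

# Toy batch: the id 9 plays the role of the end-of-sequence token,
# which has the highest id in this hypothetical vocabulary.
input_ids = jnp.array([[5, 2, 9, 0], [3, 9, 0, 0]])
k_h, k_p = jax.random.split(jax.random.PRNGKey(0))
hidden_states = jax.random.normal(k_h, (2, 4, 6))  # (batch, seq, hidden)
projection = jax.random.normal(k_p, (6, 3))        # (hidden, projection_dim)

# Pool the hidden state at each sequence's end-of-sequence position.
eos_positions = input_ids.argmax(axis=-1)
pooled = hidden_states[jnp.arange(input_ids.shape[0]), eos_positions]

# Bias-free linear projection into the shared embedding space.
text_embeds = pooled @ projection
```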
- class easydel.modules.clip.modeling_clip_flax.CLIPTextTransformer(*args: Any, **kwargs: Any)[source]#
Bases: EasyDeLBaseModule
The transformer encoder for the CLIP text model.
- config#
Configuration object.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
JAX precision level.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- class easydel.modules.clip.modeling_clip_flax.CLIPVisionEmbeddings(*args: Any, **kwargs: Any)[source]#
Bases: Module
Constructs the vision embeddings for CLIP.
- config#
Configuration object.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
JAX precision level.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
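Vision embeddings in ViT-style models such as CLIP are typically built by splitting the image into patches, projecting each patch linearly, prepending a class token, and adding position embeddings. The sketch below illustrates those steps under assumed conventions (channels-last images, hypothetical helper names and sizes); it is not the EasyDeL implementation.

```python
import jax
import jax.numpy as jnp


def patchify(images, patch_size):
    """Split (batch, H, W, C) images into flattened non-overlapping patches."""
    b, h, w, c = images.shape
    gh, gw = h // patch_size, w // patch_size
    x = images.reshape(b, gh, patch_size, gw, patch_size, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # group the two grid axes together
    return x.reshape(b, gh * gw, patch_size * patch_size * c)


def embed_image(images, patch_proj, class_embedding, position_embedding, patch_size):
    """Project patches, prepend the class token, add position embeddings."""
    patches = patchify(images, patch_size) @ patch_proj  # (b, n, hidden)
    cls = jnp.broadcast_to(class_embedding, (patches.shape[0], 1, patches.shape[-1]))
    tokens = jnp.concatenate([cls, patches], axis=1)     # (b, n + 1, hidden)
    return tokens + position_embedding


# Toy sizes: 4x4 RGB images, 2x2 patches (4 patches), hidden size 8.
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
images = jax.random.normal(k1, (2, 4, 4, 3))
patch_proj = jax.random.normal(k2, (2 * 2 * 3, 8))
class_embedding = jax.random.normal(k3, (8,))
position_embedding = jax.random.normal(k4, (5, 8))
vision_tokens = embed_image(images, patch_proj, class_embedding, position_embedding, 2)
```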
- class easydel.modules.clip.modeling_clip_flax.CLIPVisionModel(*args: Any, **kwargs: Any)[source]#
Bases: EasyDeLBaseModule
Bare CLIP vision model (transformer) outputting raw hidden-states without any specific head on top.
- config#
Configuration object.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
JAX precision level.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- class easydel.modules.clip.modeling_clip_flax.CLIPVisionTransformer(*args: Any, **kwargs: Any)[source]#
Bases: EasyDeLBaseModule
The transformer encoder for the CLIP vision model.
- config#
Configuration object.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
JAX precision level.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
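The dtype, param_dtype, and precision attributes that recur across the classes above separate how parameters are stored from how computation is carried out. A minimal sketch of that pattern in plain JAX follows, assuming float32 storage and bfloat16 compute; the variable names and values are illustrative only.

```python
import jax
import jax.numpy as jnp

param_dtype = jnp.float32  # how weights are stored
dtype = jnp.bfloat16       # how the matmul is computed

w = jnp.ones((4, 4), dtype=param_dtype)
x = jnp.ones((2, 4), dtype=jnp.float32)

# Cast inputs and weights to the compute dtype just before use;
# the stored parameter itself keeps its original dtype.
y = x.astype(dtype) @ w.astype(dtype)

# The precision argument requests a matmul precision level from the backend.
y_high = jnp.dot(x, w, precision=jax.lax.Precision.HIGHEST)
```

Keeping parameters in a wider type than the compute type is a common way to trade a little accuracy for memory bandwidth and speed on accelerators.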